{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/sharedata/mdy/miniforge/envs/cuda128/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-04-27 12:17:55,119] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/sharedata/mdy/miniforge/envs/cuda128/compiler_compat/ld: cannot find -laio: No such file or directory\n",
      "collect2: error: ld returned 1 exit status\n",
      "/sharedata/mdy/miniforge/envs/cuda128/compiler_compat/ld: cannot find -laio: No such file or directory\n",
      "collect2: error: ld returned 1 exit status\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import triton\n",
    "import triton.language as tl\n",
    "from copy import deepcopy\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "os.environ['TRITON_PRINT_AUTOTUNING'] = '1'\n",
    "from transformers import Qwen2ForCausalLM, AutoModelForCausalLM, AutoConfig\n",
    "from trl import GRPOTrainer\n",
    "import torch.nn.functional as F\n",
    "from triton_grpo_loss.core import triton_grpo_loss\n",
    "\n",
    "\n",
     "def compare(x, y):\n",
     "    # Print the max / mean element-wise relative difference between two tensors.\n",
     "    # Silently does nothing if either tensor is missing.\n",
     "    if x is None or y is None:\n",
     "        return\n",
     "    # If either tensor is fp32, upcast both so the subtraction happens in fp32.\n",
     "    if any([x.dtype == torch.float32, y.dtype==torch.float32]):\n",
     "        x,y = x.float(), y.float()\n",
     "    diff = (x-y).abs()\n",
     "    # Relative difference; the 1e-5 term guards against division by ~0.\n",
     "    diff = diff / (torch.max(x.abs(), y.abs()) + 1e-5)\n",
     "    print(f\"最大差异: {diff.max().item()}, 平均差异: {diff.mean().item()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# logp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Slightly modified from the original function\n",
     "def selective_log_softmax(logits, input_ids, temperature=0.9):\n",
     "    # Reference torch implementation: per-token log-probs of input_ids under logits.\n",
     "    logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred\n",
     "    logits_to_keep = logits.size(1)\n",
     "    # Align labels with the kept logit positions (input_ids may be longer than logits).\n",
     "    index = input_ids[:, -logits_to_keep:]\n",
     "    logits = logits[:, -logits_to_keep:]  # no-op here (logits_to_keep == logits.size(1)); kept from the original\n",
     "    logits = logits / temperature\n",
     "\n",
     "    if logits.dtype in [torch.float32, torch.float64]:\n",
     "        selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)\n",
     "        # loop to reduce peak mem consumption\n",
     "        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])\n",
     "        per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)\n",
     "    else:\n",
     "        # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach\n",
     "        per_token_logps = []\n",
     "        for row_logits, row_labels in zip(logits, index):  # loop to reduce peak mem consumption\n",
     "            row_logps = F.log_softmax(row_logits, dim=-1)\n",
     "            row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)\n",
     "            per_token_logps.append(row_per_token_logps)\n",
     "        per_token_logps = torch.stack(per_token_logps)\n",
     "    return per_token_logps"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## triton code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
     "# @triton.autotune([triton.Config({\"BLOCK_N\":BLOCK_N}, num_stages=ns, num_warps=nw)\n",
     "#                   for BLOCK_N in [2048, 4096, 8192]\n",
     "#                   for ns in [1, 2, 4]\n",
     "#                   for nw in [1, 2, 4, 8, 16]],\n",
     "#                   key=['N'])\n",
     "@triton.jit\n",
     "def _selective_log_softmax_kernel(LOGITS,\n",
     "                                  INPUT_IDS,\n",
     "                                  LOG_P,\n",
     "                                  MASK,\n",
     "                                  TEMPERATURE,\n",
     "                                  stride_input_ids_b,\n",
     "                                  L: tl.constexpr,\n",
     "                                  N: tl.constexpr,\n",
     "                                  BLOCK_N:tl.constexpr=4096):\n",
     "    # One program per (batch, position): computes log_softmax(logits / T)[id] for one token.\n",
     "    # Cast program ids to int64 so the pointer offsets below cannot overflow for large B*L*N.\n",
     "    off_b = tl.program_id(0).cast(tl.int64)\n",
     "    off_l = tl.program_id(1).cast(tl.int64)\n",
     "\n",
     "    # LOGITS is (B, L+1, N) contiguous; only the first L positions are read by this kernel.\n",
     "    LOGITS += off_b * (L+1) * N + off_l * N\n",
     "    INPUT_IDS += off_b * stride_input_ids_b + off_l\n",
     "    LOG_P += off_b * L + off_l\n",
     "\n",
     "    \n",
     "    if MASK is not None:\n",
     "        MASK += off_b * stride_input_ids_b + off_l\n",
     "        # Skip masked-out positions entirely; LOG_P keeps whatever it was initialized to.\n",
     "        not_skip = tl.load(MASK)\n",
     "        if not_skip == 0:\n",
     "            return\n",
     "\n",
     "    # Streaming (online) logsumexp over the vocab dimension in BLOCK_N-sized chunks:\n",
     "    # m_i tracks the running max, l_i the running sum of exp(logit - m_i).\n",
     "    m_i = float('-inf')\n",
     "    l_i = 0. \n",
     "    for start in range(0, N, BLOCK_N):\n",
     "        cols = start + tl.arange(0, BLOCK_N)\n",
     "        logits = tl.load(LOGITS + cols, mask=cols < N, other=float('-inf')).to(tl.float32) / TEMPERATURE\n",
     "        new_m_i = tl.maximum(m_i, tl.max(logits))\n",
     "        alpha = tl.exp(m_i - new_m_i)  # rescale the running sum to the new max\n",
     "        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))\n",
     "        m_i = new_m_i\n",
     "    lse = m_i + tl.log(l_i)\n",
     "\n",
     "    # log_softmax(x)[id] = x[id]/T - logsumexp(x/T)\n",
     "    ids = tl.load(INPUT_IDS)\n",
     "    x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE\n",
     "    logp = x - lse\n",
     "    tl.store(LOG_P, logp)\n",
     "    \n",
    "\n",
     "# Used to compute old_logp and ref_logp; gradients are not needed, so only the forward pass is implemented\n",
     "def fused_selective_log_softmax(logits:torch.Tensor, input_ids:torch.Tensor, temperature:float=0.9, mask=None):\n",
     "    # Triton-fused per-token log-probs. logits: (B, L+1, N) contiguous; returns float32 (B, L).\n",
     "    assert logits.is_contiguous()\n",
     "    B, L_ADD_1, N = logits.shape\n",
     "    L = L_ADD_1 - 1\n",
     "    # Keep only the last L labels/mask entries so they line up with logits[:, :-1].\n",
     "    input_ids = input_ids[:, -L:]\n",
     "    if mask is not None:\n",
     "        mask = mask[:, -L:]\n",
     "    # Positions skipped by the kernel (mask == 0) stay at this initial value of 0.\n",
     "    log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)\n",
     "    kwargs = {\"BLOCK_N\":2048, \"num_stages\":4, \"num_warps\":1}  # hand-picked config; see commented-out autotune\n",
     "    # Grid: one program per (batch, token position).\n",
     "    # NOTE(review): presumably L stays within the launch grid's second-dimension limit -- confirm for very long sequences.\n",
     "    _selective_log_softmax_kernel[(B, L)](logits,\n",
     "                                          input_ids,\n",
     "                                          log_p,\n",
     "                                          mask,\n",
     "                                          temperature,\n",
     "                                          input_ids.stride(0),\n",
     "                                          L,\n",
     "                                          N,\n",
     "                                          **kwargs\n",
     "                                          )\n",
     "    return log_p\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 精度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
     "vocab_size = 128 * 1000\n",
     "B, L = 8, 1024\n",
     "device = \"cuda\"\n",
     "dtype = torch.bfloat16\n",
     "# logits carry one extra position (L + 1), matching what both implementations expect.\n",
     "logits = torch.randn(B, L + 1, vocab_size, device=device, dtype=dtype)\n",
     "# input_ids / mask are deliberately longer than L to exercise the [-L:] alignment slicing.\n",
     "input_ids = torch.randint(0, vocab_size-1, (B, L + 100), dtype=torch.int64, device=device)\n",
     "mask = torch.ones(B, L+100, dtype=torch.int64, device=device)\n",
     "# mask[:, -200:] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch-bf16 vs torch-fp32\n",
      "最大差异: 0.005191626027226448, 平均差异: 0.001272668712772429\n",
      "triton-bf16 vs torch-fp32\n",
      "最大差异: 1.84596260055514e-07, 平均差异: 1.1364876151276349e-08\n"
     ]
    }
   ],
   "source": [
     "y1 = selective_log_softmax(logits, input_ids)\n",
     "y2 = fused_selective_log_softmax(logits, input_ids, mask=None)\n",
     "# Gold reference: the torch implementation run fully in float32.\n",
     "gold_y = selective_log_softmax(logits.float(), input_ids)\n",
     "print(\"torch-bf16 vs torch-fp32\")\n",
     "compare(y1, gold_y)\n",
     "print(\"triton-bf16 vs torch-fp32\")\n",
     "compare(y2, gold_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 速度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5387963652610779\n",
      "4.668179988861084\n"
     ]
    }
   ],
   "source": [
     "# do_bench reports runtime in milliseconds; fused kernel first, torch reference second.\n",
     "print(triton.testing.do_bench(lambda:fused_selective_log_softmax(logits, input_ids, mask=mask)))\n",
     "print(triton.testing.do_bench(lambda:selective_log_softmax(logits, input_ids)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAk0AAAGxCAYAAAB/QoKnAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAU8BJREFUeJzt3Xd4FGXj9fHvpiekEiAk0nsvAiJFQAmEImBvqCBIB6VJkyKggKABBAQbRR/1UX4qIF16kSIovSNSpElLCCF17/cPXvIQaUsIzG5yPte1l9md2d1zR9g9zNwzYzPGGERERETkltysDiAiIiLiClSaRERERByg0iQiIiLiAJUmEREREQeoNImIiIg4QKVJRERExAEqTSIiIiIOUGkSERERcYCH1QHuNbvdzvHjxwkICMBms1kdR0RERBxgjOHixYtERETg5uYc23iyfGk6fvw4+fPntzqGiIiIZMDRo0fJly+f1TGAbFCaAgICgCu/9MDAQIvTiIiIiCNiY2PJnz9/2ve4M8jypenqLrnAwECVJhERERfjTFNrnGMnoYiIiIiTU2kSERERcYBKk4iIiIgDsvycJkelpqaSnJxsdYwsw9PTE3d3d6tjiIiIZJpsX5qMMZw8eZILFy5YHSXLCQ4OJm/evE41iU9ERCSjsn1pulqY8uTJg5+fn77gM4Exhvj4eE6fPg1AeHi4xYlERETuXrYuTampqWmFKTQ01Oo4WYqvry8Ap0+fJk+ePNpVJyIiLi9bTwS/OofJz8/P4iRZ09Xfq+aKiYhIVpCtS9NV2iV3b+j3KiIiWYlKk4iIiIgDVJqyuHfeeYdKlSpZHUNERMTlqTS5qHr16tG9e/fbrte7d2+WLl2adr9169Y88cQT9y6YiIhIFqXSlEUZY0hJScHf319HBoqIyD1xKekS49ePx27sVke5L1SaXFDr1q1ZuXIl48ePx2azYbPZmD59OjabjQULFlClShW8vb1Zs2ZNut1z77zzDjNmzGD27Nlpz1uxYgUA27dv57HHHsPX15fQ0FDat29PXFxcuvd84okn+OCDDwgPDyc0NJQuXbroyDgRkWxq2aFlVJhSge6LujP5t8lWx7kvsvV5mv7NGIiPt+a9/fzA0YPNxo8fz759+yhXrhzDhg0DYOfOnQD069ePDz74gCJFihASEpJWiuDKrrrdu3cTGxvLtGnTAMiZMyeXLl0iKiqKGjVq8Ntvv3H69Glef/11unbtyvTp09Oev3z5csLDw1m+fDkHDhzg+eefp1KlSrRr1y5TfgciIuL8YhNj6fNLHz7Z/AkA+QPzUzy0uMWp7g+VpmvEx4O/vzXvHRcHOXI4tm5QUBBeXl74+fmRN29eAPbs2QPAsGHDaNCgwQ2f5+/vj6+vL4mJiWnPA5gxYwYJCQl8+eWX5Pj/ISZOnEizZs14//33CQsLAyAkJISJEyfi7u5OqVKlaNq0KUuXLlVpEhHJJhbsX0D7ue05FnsMgE5VOzEqchSB3oEWJ7s/VJqymKpVq97xc3bv3k3FihXTChNArVq1sNvt7N27N600lS1bNt2ZvcPDw9m+ffvdhxYREad27vI5eizqwZdbvwSgSEgRvmj+BfUK1bM22H2m0nQNP78rW3yseu/MkMPRzVUZ4Onpme6+zWbDbs8ek/9ERLKrn3b/RKd5nTh16RQ2bHR/uDvDHx1ODq97933jrFSarmGzOb6LzGpeXl6kpqZmyvNKly7N9OnTuXTpUlrpWrt2LW5ubpQsWTJT8oqIiGs5fek03RZ04/ud3wNQKlcppjafSo38NSxOZh0dPeeiChUqxIYNG/jrr784c+aMw1t8ChUqxLZt29i7dy9nzpwhOTmZli1b4uPjQ6tWrdixYwfLly+nW7duvPLKK2m75kREJHswxvDt9m8pM6kM3+/8HnebOwNqD+CPDn9k68IEKk0uq3fv3ri7u1OmTBly587NkSNHHHpeu3btKFmy
JFWrViV37tysXbsWPz8/Fi1axLlz56hWrRrPPPMM9evXZ+LEifd4FCIi4kyOXzzOE989wUs/vsTZy2epEFaBje028l799/Dx8LE6nuVsxhhjdYh7KTY2lqCgIGJiYggMTD+7PyEhgUOHDlG4cGF8fPSHIbPp9ysi4hqMMUzbMo2ei3oSkxiDp5sng+oMom/tvni5e1mS6Vbf31bRnCYREZFs7PCFw7T7uR2//PkLANUiqjG1xVTK5SlncTLno9IkIiKSDdmNnSmbptB3SV/ikuLw8fBhWL1h9KjRAw831YMb0W9FREQkm9l/dj+v//w6qw6vAqB2gdp80fwLSoSWsDiZc1NpEhERySZS7amMWz+OgcsHkpCSQA7PHIyKHEXnap1xs+nYsNtRaRIREckGdv2zizaz27Dh7w0A1C9cn8+afUbhkMIWJ3MdKk0iIiJZWHJqMqPXjmbYqmEkpSYR6B3Ihw0/pG3lttgcvVK8ACpNIiIiWdaWk1t4bfZrbDm5BYCmxZsy5fEp5AvMZ20wF6XSJCIiksUkpiQyfNVw3l/7Pin2FHL65uSjRh/xUvmXtHXpLqg0yS21bt2aCxcuMGvWLKujiIiIAzYc20CbOW3Y9c8uAJ4p8wwTG08kzF+XxbpbmirvourVq0f37t2tjiEiIk4iPjme3ot7U3NqTXb9s4s8OfIw89mZzHx2pgpTJtGWpmwsKSkJLy9rTo8vIiKZZ9XhVbSd05YD5w4A8HKFlxkXNY5Qv1CLk2Ut2tLkglq3bs3KlSsZP348NpsNm83GX3/9xcqVK3nooYfw9vYmPDycfv36kZKSkva8evXq0bVrV7p3706uXLmIiooCYOfOnTz++OMEBgYSEBDAI488wsGDB9O95wcffEB4eDihoaF06dKF5OTk+zpmERG53sXEi3SZ14W60+ty4NwBHgh4gLkvzuWrJ79SYboHtKXJBY0fP559+/ZRrlw5hg0bBkBqaipNmjShdevWfPnll+zZs4d27drh4+PDO++8k/bcGTNm0KlTJ9auXQvA33//TZ06dahXrx7Lli0jMDCQtWvXpitby5cvJzw8nOXLl3PgwAGef/55KlWqRLt27e7ruEVE5H8WH1xMu5/bcSTmCADtHmzHmAZjCPIJsjhZ1qXSdA1jDPHJ8Za8t5+nn8NHNAQFBeHl5YWfnx958+YF4O233yZ//vxMnDgRm81GqVKlOH78OH379mXw4MG4uV3ZqFi8eHFGjx6d9loDBgwgKCiI//73v3h6egJQokT60+iHhIQwceJE3N3dKVWqFE2bNmXp0qUqTSIiFriQcIFei3oxdctUAAoFF+KzZp8RWSTS4mRZn0rTNeKT4/Ef6W/Je8f1jyOHV44MP3/37t3UqFEjXfGqVasWcXFxHDt2jAIFCgBQpUqVdM/bsmULjzzySFphupGyZcvi7u6edj88PJzt27dnOKuIiGTMnL1z6DSvE8cvHgeg20PdGFF/BP5e1nx3ZTcqTdlMjhzpi5mvr+9tn/PvQmWz2bDb7ZmaS0REbu5M/BneXPgm32z/BoDiOYvzRfMveKTgIxYny15Umq7h5+lHXP84y977Tnh5eZGampp2v3Tp0vzwww8YY9K2Nq1du5aAgADy5bv5mV8rVKjAjBkzSE5OvuXWJhERuf+MMczcNZOu87vyT/w/uNnc6F2jN+/Uewdfz9v/o1cyl0rTNWw2213tIrufChUqxIYNG/jrr7/w9/enc+fOjBs3jm7dutG1a1f27t3LkCFD6NmzZ9p8phvp2rUrEyZM4IUXXqB///4EBQWxfv16HnroIUqWLHkfRyQiItc6GXeSzvM689OenwAom7ss01pMo9oD1SxOln1ZesqBkSNHUq1aNQICAsiTJw9PPPEEe/fuTbdOvXr10g6rv3rr2LGjRYmdR+/evXF3d6dMmTLkzp2b5ORk5s+fz8aNG6lYsSIdO3akbdu2DBw48JavExoayrJly4iLi6Nu3bpUqVKF
zz77TFudREQsYozhy61fUmZSGX7a8xMebh4MrjOYze03qzBZzGaMMVa9eaNGjXjhhReoVq0aKSkpDBgwgB07drBr1660uTf16tWjRIkSaYfWA/j5+REYGOjQe8TGxhIUFERMTMx1z0lISODQoUMULlwYHx+fzBuYAPr9iojcqaMxR+kwtwMLDiwA4MHwB5nafCoV81a0ONn9d6vvb6tYuntu4cKF6e5Pnz6dPHnysHnzZurUqZP2+LWH1ouIiGQ1xhg++/0zei/uzcWki3i5ezG03lB61+yNh5tm0jgLpzojeExMDAA5c+ZM9/jXX39Nrly5KFeuHP379yc+/ubnUkpMTCQ2NjbdTURExFn9ef5P6n9Znw5zO3Ax6SIP53uYLR220K92PxUmJ+M0/zfsdjvdu3enVq1alCtXLu3xl156iYIFCxIREcG2bdvo27cve/fu5ccff7zh64wcOZKhQ4fer9giIiIZkmpPZeLGiQxYNoD45Hh8PXwZUX8E3R7qhrub++1fQO47S+c0XatTp04sWLCANWvW3PIQ+WXLllG/fn0OHDhA0aJFr1uemJhIYmJi2v3Y2Fjy58+vOU0W0O9XROTG9pzZQ9s5bfn16K8A1CtUj8+bfU7RnNd/r2VXmtN0E127dmXu3LmsWrXqloUJoHr16gA3LU3e3t54e3vfk5wiIiJ3I8Wewoe/fsiQFUNITE0kwCuA0Q1G075Ke9xsTjVjRm7A0tJkjKFbt2789NNPrFixgsKFC9/2OVu2bAGuXMojM3NI5tPvVUTkf7ad2kab2W3YfGIzAFFFo/i02acUCCpgcTJxlKWlqUuXLnzzzTfMnj2bgIAATp48CVy5IK2vry8HDx7km2++oUmTJoSGhrJt2zZ69OhBnTp1qFChwl2//9VzEcXHxzt0ORG5M1cn7OucTyKSnSWlJjFi9QjeW/0eKfYUgn2CGRc1jlcrvurwhdrFOVg6p+lmf1imTZtG69atOXr0KC+//DI7duzg0qVL5M+fnyeffJKBAwdmynmaAE6cOMGFCxfIkycPfn5++gOcCYwxxMfHc/r0aYKDgzN1q6CIiCvZdHwTbWa3YfvpKxc5b1GyBZObTiY8QJ+Lt6M5Tf9yu76WP39+Vq5ceU8zXD3/0+nTp+/p+2RHwcHBOr+WiGRLl5MvM3TlUMb8Oga7sZPLLxcTG0/kubLP6R/nLswpJoJbyWazER4eTp48eUhOTrY6Tpbh6emJu7sOmRWR7GftkbW0ndOWvWevXBbshXIv8FGjj8idI7fFyeRuZfvSdJW7u7u+5EVEJMMuJV1iwNIBTNg4AYMh3D+cyU0n06JUC6ujSSZRaRIREblLyw4t4/U5r3PowiEAXqv0Gh82/JAQ3xCLk0lmUmkSERHJoJiEGPr80odPf/8UgPyB+fms2WdEFYuyOJncCypNIiIiGTB//3w6zO3AsdhjAHSq2olRkaMI9HaOI70k86k0iYiI3IFzl8/RfWF3vtr2FQBFQ4ryRfMvqFuorsXJ5F5TaRIREXHQj7t/pPO8zpy6dAobNno83IPhjw3Hz9PP6mhyH6g0iYiI3MapuFN0W9CNmbtmAlA6V2m+aP4FNfLXsDiZ3E8qTSIiIjdhjOHbHd/yxoI3OHv5LO42d/rW6suguoPw8fCxOp7cZypNIiIiN/B37N90mteJn/f9DEDFsIpMbTGVB8MftDiZWEWlSURE5BrGGKb+MZVei3sRkxiDp5sng+sOpm+tvni66wLk2ZlKk4iIyP/314W/aPdzO5b8uQSAahHVmNpiKuXylLM4mTgDlSYREcn27MbO5N8m03dJXy4lX8LHw4fhjw6n+8Pd8XDTV6VcoT8JIiKSre0/u5+2c9qy+shqAGoXqM0Xzb+gRGgJi5OJs1FpEhGRbCnVnsrY9WMZtHwQCSkJ5PDMwajIUXSu1hk3m5vV8cQJqTSJiEi2s/P0TtrMacPGvzcCEFkkks+afUah4ELWBhOnptIkIiLZ
RnJqMu+vfZ9hK4eRbE8m0DuQ6IbRtKncBpvNZnU8cXIqTSIiki38ceIPXpv9GltPbQXg8RKPM6XpFB4IfMDiZOIqVJpERCRLS0xJZPiq4YxaM4pUk0pO35x81OgjXir/krYuyR1RaRIRkSxr/bH1tJndht1ndgPwbJlnmdB4AmH+YRYnE1ek0iQiIllOfHI8g5YNYuz6sRgMeXLk4eMmH/N0maetjiYuTKVJRESylPn759N1flcOXTgEwCsVXmFs1FhC/UItTiauTqVJRESyhGOxx+i+sDs/7P4BgHyB+ZjSdApNSzS1OJlkFSpNIiLi0lLsKUzYMIHBKwYTlxSHu82dN6u/yTv13iHAO8DqeJKFqDSJiIjLWn9sPR3ndkw7jUCNfDWY3HQyFfNWtDiZZEUqTSIi4nLOXT5H/yX9+ez3zzAYQnxCeD/yfdo+2FaXQJF7RqVJRERchjGGr7Z9Re/Fvfkn/h8AWlVsxZgGY8idI7fF6SSrU2kSERGXsOufXXSe15mVh1cCUCZ3GSY3nUydgnUsTibZhUqTiIg4tfjkeN5d9S5jfh1Dij0FXw9fBtcdTM8aPfFy97I6nmQjKk0iIuK05u6bS7cF3fjrwl/AlevFTWg8gULBhSzNJdmTSpOIiDidozFHeXPhm/y05ycA8gfm56PGH9GiZAtdL04so9IkIiJOIzk1mfEbxvPOine4lHwJd5s7PWv0ZHDdwfh7+VsdT7I5lSYREXEKvx79lY5zO7L99HYAauWvxeSmkykfVt7iZCJXqDSJiIilzsafpd+Sfnz+x+cA5PTNyZgGY2hdqbXOuSRORaVJREQsYYxh+pbp9FnShzPxZwBoU6kN7zd4n1x+uSxOJ3I9lSYREbnvdp7eSad5nVh9ZDUAZXOXZcrjU6hdoLbFyURuTqVJRETum0tJlxi2chjR66NJsafg5+nHO3XfofvD3fF097Q6nsgtqTSJiMh9MWfvHLot6MaRmCMAtCjZgvGNxlMwuKDFyUQco9IkIiL31JGYI7yx4A1m750NQIGgAkxoPIHmJZtbnEzkzqg0iYjIPZGcmszY9WMZunIo8cnxeLh50KtGLwbVGUQOrxxWxxO5YypNIiKS6dYcWUOneZ3YcXoHAI8UeITJTSdTNk9Zi5OJZJxKk4iIZJoz8Wfo80sfpm2ZBkAuv1yMaTCGVhVb6fIn4vJUmkRE5K7ZjZ1pf0yjz5I+nLt8DoDXK7/OqMhRhPqFWpxOJHOoNImIyF3Zfmo7Hed15NejvwJQPk95JjedTK0CtSxOJpK5VJpERCRD4pLiGLpiKGPXjyXVpJLDMwdD6w3ljepv6JxLkiWpNImIyB0xxjB772zeWPAGR2OPAvBkqScZ32g8+YPyW5xO5N5RaRIREYf9deEvui3oxtx9cwEoFFyIiY0n0rREU4uTidx7Kk0iInJbSalJRK+LZtjKYVxOuYynmye9a/ZmYJ2B+Hn6WR1P5L5QaRIRkVtadXgVneZ1Ytc/uwCoW7Auk5tOpnTu0hYnE7m/VJpEROSG/rn0D2/98hYzts4AILdfbj5o+AGvVHhF51ySbEmlSURE0rEbO1/8/gV9l/TlfMJ5ADpU6cCI+iPI6ZvT4nQi1nGz8s1HjhxJtWrVCAgIIE+ePDzxxBPs3bs33ToJCQl06dKF0NBQ/P39efrppzl16pRFiUVEsratJ7dSe2pt2s9tz/mE81QMq8i6tuuY8vgUFSbJ9iwtTStXrqRLly6sX7+eX375heTkZBo2bMilS5fS1unRowc///wzM2fOZOXKlRw/fpynnnrKwtQiIlnPxcSL9FrUiyqfVmHdsXX4e/kzNmosm9pv4uF8D1sdT8Qp2IwxxuoQV/3zzz/kyZOHlStXUqdOHWJiYsidOzfffPMNzzzzDAB79uyhdOnSrFu3jocfvv1f5NjYWIKCgoiJiSEwMPBeD0FExKUYY/hx94+8ufBN/r74NwDPlHmGsVFjyReYz+J0kp054/e3
U81piomJASBnziubgDdv3kxycjKRkZFp65QqVYoCBQrctDQlJiaSmJiYdj82NvYepxYRcU2Hzh+i64KuzN8/H4DCwYWZ1GQSjYs3tjiZiHOydPfctex2O927d6dWrVqUK1cOgJMnT+Ll5UVwcHC6dcPCwjh58uQNX2fkyJEEBQWl3fLn19lpRUSulZSaxIjVIyjzcRnm75+Pp5snAx8ZyM7OO1WYRG7BabY0denShR07drBmzZq7ep3+/fvTs2fPtPuxsbEqTiIi/9+Kv1bQaV4n9pzZA8CjhR7l46YfUypXKYuTiTg/pyhNXbt2Ze7cuaxatYp8+f63Dz1v3rwkJSVx4cKFdFubTp06Rd68eW/4Wt7e3nh7e9/ryCIiLuX0pdP0Xtybr7Z9BUCeHHmIbhjNS+Vf0jmXRBxk6e45Ywxdu3blp59+YtmyZRQuXDjd8ipVquDp6cnSpUvTHtu7dy9HjhyhRo0a9zuuiIjLsRs7n2z6hJITS/LVtq+wYaNT1U7s6bKHlhVaqjCJ3AFLtzR16dKFb775htmzZxMQEJA2TykoKAhfX1+CgoJo27YtPXv2JGfOnAQGBtKtWzdq1Kjh0JFzIiLZ2R8n/qDTvE5s+HsDAJXzVmbK41N46IGHLE4m4posPeXAzf6FM23aNFq3bg1cObllr169+Pbbb0lMTCQqKoqPP/74prvn/s0ZD1kUEbmXLiZeZPDywXy08SPsxk6AVwDvPvYunat1xsPNKWZliNyWM35/O9V5mu4FZ/yli4jcC8YYftj9A28ufJPjF48D8FzZ5xgbNZaIgAiL04ncGWf8/tY/OUREsoCD5w7SdUFXFh5YCEDRkKJMajKJqGJRFicTyTpUmkREXFhiSiJjfh3De6vfIyElAS93L/rV6ke/2v3w9fS1Op5IlqLSJCLiopYdWkbneZ3Ze/bKhc7rF67Px00/pkRoCYuTiWRNKk0iIi7mVNwpei3uxdfbvwYgLEcYY6PG8kK5F3QKAZF7SKVJRMRFpNpT+WTzJwxYOoCYxBhs2OhcrTPvPvYuwT7BVscTyfJUmkREXMDvJ36n49yO/Hb8NwCqhFdhyuNTqBpR1eJkItmHSpOIiBOLSYhh0PJBTPptEnZjJ9A7kPcee49OVTvh7uZudTyRbEWlSUTECRlj+H7n9/RY1IMTcScAeKHcC0Q3jCY8INzidCLZk0qTiIiT2X92P10XdGXxwcUAFM9ZnElNJtGgaAOLk4lkbypNIiJOIiElgffXvM/INSNJTE3E292b/rX707d2X3w8fKyOJ5LtqTSJiDiBXw7+Qpf5Xdh/bj8ADYs2ZGLjiRQPLW5xMhG5SqVJRMRCJy6eoOfinvx3x38BCPcPZ2zUWJ4r+5zOuSTiZFSaREQskGpPZfKmyby97G1iE2Nxs7nRtVpXhj06jCCfIKvjicgNqDSJiNxnm45vouPcjmw+sRmAahHVmPL4FB4Mf9DiZCJyKypNIiL3ycXEiwxYOoBJv03CYAjyDmJE/RF0qNJB51wScQEqTSIi98GiA4toP7c9R2KOANCyfEs+aPgBef3zWpxMRByl0iQicg+du3yOnot6MmPrDAAKBxfm02afElkk0uJkInKnVJpERO6RH3f/SOd5nTl16RQ2bLxR/Q3ee+w9cnjlsDqaiGSASpOISCY7FXeKrgu68n+7/g+AUrlK8UXzL6iZv6bFyUTkbqg0iYhkEmMM/9n2H7ov6s65y+dwt7nTr3Y/BtYZqDN6i2QBKk0iIpngSMwROs7tyIIDCwColLcSU5tPpXJ4ZYuTiUhmUWkSEbkLdmPnk02f0GdJH+KS4vB292ZI3SH0rtkbT3dPq+OJSCZSaRIRyaD9Z/fz+s+vs+rwKgBq5q/JF82/oFSuUhYnE5F7QaVJROQOpdhTGLd+HIOWDyIhJQE/Tz9G1h9Jl2pddJJKkSxMpUlE5A5sP7WdtnPa8tvx3wCILBLJp49/SuGQwhYnE5F7TaVJRMQBSalJ
jFg9ghGrR5BsTybIO4joqGheq/QaNpvN6ngich+oNImI3MZvf/9Gmzlt2HF6BwAtSrbg46YfExEQYXEyEbmfVJpERG4iPjmeIcuHEL0+Gruxk9svNxMaT+C5ss9p65JINqTSJCJyA6sOr6LtnLYcOHcAgJfKv8T4RuPJ5ZfL4mQiYhWVJhGRa8QmxtJvST8mb5oMwAMBDzDl8Sk8XuJxi5OJiNVUmkRE/r8F+xfQYW4HjsYeBaDdg+0Y02AMQT5BFicTEWeg0iQi2d7Z+LP0WNSDr7Z9BUCRkCJ81uwzHiv8mMXJRMSZqDSJSLb2f7v+jy7zu3D60mls2Oj+cHeGPzqcHF45rI4mIk5GpUlEsqWTcSfpMr8LP+7+EYDSuUoztcVUHs73sMXJRMRZqTSJSLZijOHLrV/SY1EPziecx8PNg361+jGwzkC8PbytjiciTkylSUSyjSMxR2j/c3sWHVwEwIPhDzK1+VQq5q1ocTIRcQUqTSKS5dmNnSmbptB3SV/ikuLwdvdmaL2h9KrZCw83fQyKiGP0aSEiWdq+s/t4fc7rrD6yGoBa+WvxRfMvKJmrpMXJRMTVqDSJSJaUYk8hel00Q1YMISElgRyeORgVOYrO1TrjZnOzOp6IuCCVJhHJcrad2kbbOW3ZdHwTAA2KNODTZp9SKLiQtcFExKWpNIlIlpGYksiI1SMYsWYEKfYUgn2CiW4YTetKrXWBXRG5aypNIpIlbDi2gbZz2rLzn50APFHqCT5u8jHhAeEWJxORrEKlSURcWnxyPIOWDWLchnHYjZ3cfrmZ1GQSz5R5RluXRCRTqTSJiMta8dcKXp/zOgfPHwTg5QovMzZqLLn8clmcTESyIpUmEXE5sYmx9PmlD59s/gSAfIH5mNJ0Ck1LNLU4mYhkZSpNIuJS5u+fT4e5HTgWewyADlU6MLrBaAK9Ay1OJiJZnUqTiLiEs/Fn6b6oO//Z9h8AioYU5fPmn1OvUD1rg4lItqHSJCJOzRjDzF0z6Tq/K//E/4ObzY3u1bsz/LHh+Hn6WR1PRLIRlSYRcVonLp6g8/zOzNozC4CyucvyRfMvqJ6vurXBRCRbUmkSEadjjGH6lun0XNyTCwkX8HDzYEDtAQx4ZADeHt5WxxORbEqlSUScyl8X/qLD3A4sPrgYgCrhVZjaYioVwipYnExEsjtLr1q5atUqmjVrRkREBDabjVmzZqVb3rr1lUsfXHtr1KiRNWFF5J6yGzsTN06k3MflWHxwMT4ePrwf+T7rX1+vwiQiTsHSLU2XLl2iYsWKtGnThqeeeuqG6zRq1Ihp06al3ff21qZ5kaxm75m9vP7z66w5sgaARwo8wufNP6dEaAmLk4mI/I+lpalx48Y0btz4lut4e3uTN2/e+5RIRO6nFHsKH/76IUNWDCExNRF/L3/ej3yfjlU74mazdEO4iMh1nH5O04oVK8iTJw8hISE89thjvPvuu4SGhlodS0Tu0taTW2kzpw2/n/gdgKiiUXzy+CcUDC5ocTIRkRtz6tLUqFEjnnrqKQoXLszBgwcZMGAAjRs3Zt26dbi7u9/wOYmJiSQmJqbdj42NvV9xRcQBiSmJvLvqXUatHUWKPYUQnxDGRo3l1Yqv6gK7IuLUnLo0vfDCC2k/ly9fngoVKlC0aFFWrFhB/fr1b/ickSNHMnTo0PsVUUTuwPpj62k7py27/tkFwFOln2JSk0nk9dcueBFxfi41aaBIkSLkypWLAwcO3HSd/v37ExMTk3Y7evTofUwoIjdyKekSPRf1pOYXNdn1zy7y5MjDzGdn8sNzP6gwiYjLcOotTf927Ngxzp49S3h4+E3X8fb21hF2Ik5k+aHlvP7z6/x5/k8AXq34KtENown109xEEXEtlpamuLi4dFuNDh06xJYtW8iZMyc5c+Zk6NChPP300+TNm5eDBw/Sp08fihUrRlRUlIWpRcQRMQkx9PmlD5/+/ikA+QLz8enjn9K4+K2PmBURcVaWlqZN
mzbx6KOPpt3v2bMnAK1atWLy5Mls27aNGTNmcOHCBSIiImjYsCHDhw/XliQRJzd331w6zu3I3xf/BqBT1U6MihxFoHegxclERDLOZowxVoe4l2JjYwkKCiImJobAQH1gi9xLZ+LP8ObCN/lm+zcAFMtZjM+bfU7dQnUtTiYirsYZv78zNBF8xowZzJs3L+1+nz59CA4OpmbNmhw+fDjTwomIazDG8N2O7ygzqQzfbP8GN5sbvWv0ZmvHrSpMIpJlZKg0jRgxAl9fXwDWrVvHpEmTGD16NLly5aJHjx6ZGlBEnNvxi8d58rsneeGHF/gn/h/K5SnH+rbrGdNwDH6eflbHExHJNBma03T06FGKFSsGwKxZs3j66adp3749tWrVol69epmZT0SclDGGqX9MpdfiXsQkxuDp5smARwYw4JEBeLl7WR1PRCTTZag0+fv7c/bsWQoUKMDixYvTJnD7+Phw+fLlTA0oIs7n0PlDtJ/bniV/LgGgWkQ1vmj+BeXDylucTETk3slQaWrQoAGvv/46lStXZt++fTRp0gSAnTt3UrCgrhslklXZjZ2JGycyYOkALiVfwsfDh+GPDqf7w93xcHOp076JiNyxDH3KTZo0iYEDB3L06FF++OGHtAvobt68mZdeeilTA4qIc9hzZg9t57Tl16O/AlCnYB0+b/Y5xUOLW5xMROT+yPApBxISEti2bRunT5/GbrenW9a8efNMCZcZnPGQRRFXkpyazAe/fsDQlUNJTE3E38uf0ZGj6VC1A242l7oSk4i4EGf8/s7QlqaFCxfy6quvcvbsWf7duWw2G6mpqZkSTkSsteXkFtrMbsMfJ/8AoFGxRnzy+CcUCCpgcTIRkfsvQ/9M7NatG88++yzHjx/Hbrenu6kwibi+hJQEBi4bSLXPqvHHyT8I8QlhxhMzmP/SfBUmEcm2MrSl6dSpU/Ts2ZOwsLDMziMiFlt3dB1t5rRhz5k9ADxT5hkmNp5ImL/+votI9pahLU3PPPMMK1asyOQoImKlS0mX6L6wO7Wm1mLPmT2E5Qjjh+d+YOazM1WYRETI4ETw+Ph4nn32WXLnzk358uXx9PRMt/yNN97ItIB3yxknkok4mxV/raDN7DYcunAIgFYVWxEdFU1O35wWJxOR7MoZv78ztHvu22+/ZfHixfj4+LBixQpsNlvaMpvN5lSlSURu7lLSJfov7c+EjRMAKBBUgE8e/4RGxRpZnExExPlkqDS9/fbbDB06lH79+uHmpkOORVzR6sOreW32axw8fxCA9g+2Z0zDMQR6O8e/6EREnE2GSlNSUhLPP/+8CpOIC4pPjuftpW8zfsN4DIb8gfn5vPnnNCza0OpoIiJOLUOtp1WrVnz33XeZnUVE7rFfj/5KpSmVGLdhHAZD28pt2d5puwqTiIgDMrSlKTU1ldGjR7No0SIqVKhw3UTw6OjoTAknIpnjcvJlBi8fzIfrPsRgiAiI4PNmn9O4eGOro4mIuIwMlabt27dTuXJlAHbs2JFu2bWTwkXEehuObaDVrFbsPbsXgNaVWjM2aizBPsHWBhMRcTEZKk3Lly/P7BwikskSUhJ4Z8U7jPl1DHZjJ9w/nE+bfcrjJR63OpqIiEvKUGkSEef229+/0Xp2a3b9swuAlyu8zPhG43XeJRGRu6DSJJKFJKYkMmzlMN5f+z6pJpWwHGF88vgntCjVwupoIiIuT6VJJIv4/cTvtJrVih2nr8wzfLHci0xoPIFQv1CLk4mIZA0qTSIuLik1ifdWvcd7q98j1aSS2y83k5tO5ukyT1sdTUQkS1FpEnFhW09updWsVmw9tRWAZ8s8y6Qmk8idI7fFyUREsh6VJhEXlJyazKg1oxi2ahgp9hRCfUP5uOnHPFf2OaujiYhkWSpNIi5m+6nttJ7dmt9P/A7Ak6WeZHLTyYT5h1mcTEQka1NpEnERKfYURq8dzTsr3iHZnkyITwiTmkzihXIv6KSyIiL3gUqTiAvY9c8uWs1q
xabjmwBoXrI5U5pOITwg3OJkIiLZh0qTiBNLsafw4a8fMnjFYJJSkwj2CWZC4wm0LN9SW5dERO4zlSYRJ7XnzB5az2rNhr83ANC0eFM+bfYpEQERFicTEcmeVJpEnEyqPZVx68fx9rK3SUxNJNA7kPGNxtOqYittXRIRsZBKk4gT2Xd2H6/Nfo1fj/4KQKNijfis2WfkC8xncTIREVFpEnECdmPnow0f0X9pfxJSEgjwCmBs1FjaVG6jrUsiIk5CpUnEYgfOHaDN7DasPrIagMgikXzR/AsKBBWwOJmIiFxLpUnEInZjZ9LGSfRb2o/45Hj8vfz5oMEHtK/SXluXRESckEqTiAX+PP8nbWa3YeXhlQA8WuhRpraYSqHgQtYGExGRm1JpErmP7MbOJ5s+4a1f3uJS8iX8PP0Y02AMHat2xM3mZnU8ERG5BZUmkfvk8IXDtJ3TlqWHlgJQp2AdprWYRpGQIhYnExERR6g0idxjxhg+//1zei7uSVxSHL4evoyKHEXXh7pq65KIiAtRaRK5h47GHOX1n19n8cHFANTKX4vpT0ynWM5iFicTEZE7pdIkcg8YY5i2ZRo9FvUgNjEWHw8fRjw2gjeqv4G7m7vV8UREJANUmkQy2d+xf9Pu53YsOLAAgBr5ajCtxTRK5ippcTIREbkbKk0imcQYw5dbv+TNhW8SkxiDt7s37z72Lj0e7qGtSyIiWYBKk0gmOH7xOB3mdmDuvrkAPPTAQ0xvMZ3SuUtbnExERDKLSpPIXTDG8M32b+i2oBvnE87j5e7F0HpD6V2zNx5u+uslIpKV6FNdJINOxp2k49yOzN47G4Aq4VWY/sR0yuUpZ3EyERG5F1SaRO6QMYbvdn5Hl/ldOHf5HJ5ungypO4Q+tfrg6e5pdTwREblHVJpE7sDpS6fpPK8zP+z+AYBKeSsx44kZVAirYHEyERG511SaRBw0c+dMOs/vzJn4M3i4eTCoziD61+6vrUsiItmESpPIbZyJP0PX+V35bud3AFQIq8CMJ2ZQKW8la4OJiMh9pdIkcgs/7f6JjvM6cvrSadxt7gx4ZAAD6wzEy93L6mgiInKfqTSJ3MDZ+LO8sfANvtn+DQBlc5dlxhMzqBJRxeJkIiJiFUsvsb5q1SqaNWtGREQENpuNWbNmpVtujGHw4MGEh4fj6+tLZGQk+/fvtyasZBtz9s6h3ORyfLP9G9xsbvSv3Z/N7TerMImIZHOWlqZLly5RsWJFJk2adMPlo0eP5qOPPmLKlCls2LCBHDlyEBUVRUJCwn1OKtnB+cvnefWnV2nx3xacjDtJ6VylWdd2HSPqj8Dbw9vqeCIiYjFLd881btyYxo0b33CZMYZx48YxcOBAWrRoAcCXX35JWFgYs2bN4oUXXrifUSWLm7dvHu3ntuf4xeO42dzoXaM3Qx8dio+Hj9XRRETESTjtnKZDhw5x8uRJIiMj0x4LCgqievXqrFu37qalKTExkcTExLT7sbGx9zyruK4LCRfouagn07ZMA6BEaAmmt5hOjfw1LE4mIiLOxtLdc7dy8uRJAMLCwtI9HhYWlrbsRkaOHElQUFDaLX/+/Pc0p7iuRQcWUX5yeaZtmYYNG71q9GJLhy0qTCIickNOW5oyqn///sTExKTdjh49anUkcTKxibG0m9OORl834ljsMYrlLMbq11bzQcMP8PX0tTqeiIg4KafdPZc3b14ATp06RXh4eNrjp06dolKlSjd9nre3N97emrQrN7bkzyW0md2Go7FHsWHjjepvMKL+CPw8/ayOJiIiTs5ptzQVLlyYvHnzsnTp0rTHYmNj2bBhAzVqaPeJ3JmLiRfpOLcjDb5qwNHYoxQJKcKK1isY12icCpOIiDjE0i1NcXFxHDhwIO3+oUOH2LJlCzlz5qRAgQJ0796dd999l+LFi1O4cGEGDRpEREQETzzxhHWhxeUsO7SMNrPbcDjmMABdq3VlVOQocnjlsDiZiIi4EktL06ZNm3j00UfT
7vfs2ROAVq1aMX36dPr06cOlS5do3749Fy5coHbt2ixcuBAfHx0GLrcXlxRHvyX9mPTblfOAFQouxNTmU3m08KO3eaaIiMj1bMYYY3WIeyk2NpagoCBiYmIIDAy0Oo7cJ6sOr+K12a/x5/k/AehYpSOjG4wmwDvA4mQiIuIIZ/z+dtqJ4CIZEZ8cz4ClAxi/YTwABYIK8EXzL4gsEnmbZ4qIiNyaSpNkGWuOrOG12a9x4NyVeXLtHmzHBw0/INDbOf6FIiIirk2lSVze5eTLDFw2kLHrx2Iw5AvMx+fNPieqWJTV0UREJAtRaRKXtu7oOlrPbs2+s/sAaFOpDdFR0QT5BFmcTEREshqVJnFJCSkJDF4+mA/XfYjd2IkIiOCzZp/RpHgTq6OJiEgWpdIkLmfj3xtpPas1u8/sBuDViq8yLmocIb4hFicTEZGsTKVJXEZiSiJDVw7l/bXvYzd28vrn5dPHP6VZyWZWRxMRkWxApUlcwubjm2k1qxU7/9kJQMvyLfmo8Ufk9M1pcTIREckuVJrEqSWkJPDeqvcYuWYkqSaVPDnyMKXpFJ4s/aTV0UREJJtRaRKnZIzh+53f03dJ37Rrxj1f9nkmNplILr9cFqcTEZHsSKVJnM6GYxvosagH646tA+CBgAcY12gcz5R5xuJkIiKSnak0idM4fOEw/Zf259sd3wLg5+lHv1r96FWzF36efhanExGR7E6lSSx3MfEio9aMInp9NAkpCdiw0bpSa9597F0iAiKsjiciIgKoNImFUu2pTNsyjYHLBnLq0ikA6hWqR3TDaCqHV7Y4nYiISHoqTWKJJX8uoeeinmw/vR2A4jmLM6bBGJqXbI7NZrM4nYiIyPVUmuS+2v3Pbt765S3m7Z8HQIhPCEPqDqFTtU54uXtZnE5EROTmVJrkvjgTf4ahK4YyedNkUk0qHm4edKnWhcF1B+sElSIi4hJUmuSeSkxJZOLGiQxfNZyYxBgAWpRswegGoykRWsLidCIiIo5TaZJ7whjDT3t+os8vfTh4/iAAlfJW4sOGH/JY4ccsTiciInLnVJok020+vpmei3uy6vAqAPL652XEYyN4teKruLu5W5xOREQkY1SaJNMciz3G28ve5sutXwLg6+FL75q96VOrD/5e/hanExERuTsqTXLXLiVdYvTa0Yz5dQyXUy4D8HKFlxnx2AjyB+W3OJ2IiEjmUGmSDLMbO19u/ZIBSwdwIu4EALUL1Ca6YTTVHqhmcToREZHMpdIkGbLirxX0XNSTP07+AUCRkCKMjhzNU6Wf0skpRUQkS1Jpkjuy/+x++izpw6w9swAI9A5kUJ1BdHuoG94e3taGExERuYdUmsQh5y+fZ/iq4UzcOJFkezLuNnc6VOnAO/XeIXeO3FbHExERuedUmuSWklOTmbxpMkNXDuXc5XMANCnehDENxlAmdxmL04mIiNw/Kk1yQ8YY5u6bS+9ferPv7D4AyuYuS3RUNA2LNrQ4nYiIyP2n0iTX2XJyC70W92LZoWUA5MmRh+GPDqdN5TZ4uOmPjIiIZE/6BpQ0Jy6eYNDyQUz9YyoGg7e7Nz0e7kH/R/oT6B1odTwRERFLqTQJ8cnxRK+LZtSaUVxKvgTAC+VeYGT9kRQKLmRtOBERESeh0pSN2Y2db7d/S7+l/TgWewyA6g9UZ2zUWGrkr2FxOhEREeei0pRNrTmyhp6LevLb8d8AKBBUgPcj3+f5ss/r5JQiIiI3oNKUzfx5/k/6LenHzF0zAfD38mdA7QF0f7g7vp6+FqcTERFxXipN2URMQgzvrX6P8RvGk5SahJvNjdcrv86wR4cR5h9mdTwRERGnp9KUxaXYU/hs82cMXjGYM/FnAGhQpAEfNvyQ8mHlLU4nIiLiOlSasrAF+xfQ+5fe7PpnFwClcpXiw4Yf0rhYY81bEhERuUMqTVnQjtM76L24N4sOLgIg1DeUofWG0r5KezzdPS1OJyIi4ppUmrKQ05dOM3j5YD77/TPs
xo6nmydvVH+DgXUGEuwTbHU8ERERl6bSlAUkpCQwfv143lv9HheTLgLwdOmneT/yfYrmLGpxOhERkaxBpcmFGWOYuWsmfZf05a8LfwFQNaIq0Q2jeaTgI9aGExERyWJUmlzUhmMb6LGoB+uOrQPggYAHGFl/JC0rtMTN5mZxOhERkaxHpcnFHL5wmP5L+/Ptjm8B8PP0o1+tfvSq2Qs/Tz+L04mIiGRdKk0u4mLiRUatGUX0+mgSUhKwYaN1pda8+9i7RAREWB1PREQky1NpcnKp9lSmbZnGwGUDOXXpFAD1CtUjumE0lcMrW5xOREQk+1BpcmJL/lxCz0U92X56OwDFchbjgwYf0Lxkc52cUkRE5D5TaXJCe87s4a1f3mLuvrkAhPiEMKTuEDpV64SXu5fF6URERLInlSYncib+DENXDGXypsmkmlQ83DzoUq0Lg+sOJqdvTqvjiYiIZGsqTU4gMSWRiRsnMnzVcGISYwBoXrI5YxqMoURoCYvTiYiICKg0WcoYw097fqLPL304eP4gABXDKhIdFc1jhR+zOJ2IiIhcS6XJIpuPb6bn4p6sOrwKgLz+eXnvsfdoVbEV7m7uFqcTERGRf3PqU0e/88472Gy2dLdSpUpZHeuu/B37N61mtaLqZ1VZdXgVvh6+DKoziP3d9tOmchsVJhERESfl9FuaypYty5IlS9Lue3g4feQbupR0iTG/jmH02tFcTrkMwMsVXmbEYyPIH5Tf4nQiIiJyO07fQDw8PMibN6/VMTLMbux8ufVLBiwdwIm4EwDULlCb6IbRVHugmsXpRERExFFOX5r2799PREQEPj4+1KhRg5EjR1KgQAGrYzlkxV8r6LmoJ3+c/AOAwsGFGd1gNE+XflonpxQREXExTl2aqlevzvTp0ylZsiQnTpxg6NChPPLII+zYsYOAgIAbPicxMZHExMS0+7Gxsfcrbpr9Z/fTZ0kfZu2ZBUCgdyCD6gyi20Pd8Pbwvu95RERE5O7ZjDHG6hCOunDhAgULFiQ6Opq2bdvecJ133nmHoUOHXvd4TEwMgYGB9zTf+cvnGb5qOBM3TiTZnoy7zZ0OVTrwTr13yJ0j9z19bxERkawkNjaWoKCg+/L97Sin3tL0b8HBwZQoUYIDBw7cdJ3+/fvTs2fPtPuxsbHkz39vJ1onpyYzedNkhq4cyrnL5wBoUrwJYxqMoUzuMvf0vUVEROT+cKnSFBcXx8GDB3nllVduuo63tzfe3vdnF5gxhrn75tL7l97sO7sPgLK5yxIdFU3Dog3vSwYRERG5P5y6NPXu3ZtmzZpRsGBBjh8/zpAhQ3B3d+fFF1+0OhpbTm6h1+JeLDu0DIA8OfIw/NHhtKncBg83p/61ioiISAY49bf7sWPHePHFFzl79iy5c+emdu3arF+/nty5rZ8f1Htxb5YdWoa3uzc9Hu5B/0f6E+jtHPtcRUREJPO51ETwjLhXE8n+OPEHo38dzcj6IykUXCjTXldERESccyK4SpOIiIg4HWf8/nbqa8+JiIiIOAuVJhEREREHqDSJiIiIOEClSURERMQBKk0iIiIiDlBpEhEREXGASpOIiIiIA1SaRERERByg0iQiIiLiAJUmEREREQeoNImIiIg4QKVJRERExAEqTSIiIiIOUGkSERERcYBKk4iIiIgDVJpEREREHKDSJCIiIuIAD6sDiIiIiHNISoKLF+/81rUr1K9vdfp7T6VJRETERdntEBd38zITG3tn5ScpKWM5oqJUmkRERCQTGQOXL2dsa86Nbpcu3Zuc3t4QEJD+Fhh4/WNXb3Xq3JsczkalSURE5BaSkzO21eZGt7g4SE3N/Izu7jcvNBm5eXpmfsasQKVJRESynPh4OH8+c7bmJCbem4w5cji+Jed2Nx8fsNnuTU75H5UmERFxagkJcPYsnDlz/e1mj1++nPk5brTLKqM3f39w0/HrLkelSURE7pukpBsXnZuVnzNnMj5vx83tzublaJeV3I5Kk4iI
ZEhyMpw751jxubo8NjZj7+XuDqGhkCvX9bebPR4QoF1WkrlUmkREhNRUxwrQtY9fuJCx93Jzg5w576wABQZqd5ZYT6VJRCSLsduvTIK+3dafa5edP3/lcPg7ZbNBSIjj5SdXLggOVgES16TSJCLixOx2iIlxvPycOXNli5HdnrH3Cw52vACFhl4pTB76JpFsQn/URUTuE2OuzOm5kyPBzp7N+Hl9AgMd3/oTGnpll5kmO4vcnEqTiEgGpKZe2aV19uz/ys7tfj57FlJSMvZ+/v53VoBCQ8HLK3PHLJLdqTSJSLZ39TB4R0rP1fsZnQME4OsLuXM7Vn6u/tfHJ3PHLCJ3TqVJRLKU+PibF52b/XzxYsbf7+ousKtbd27287X3fX0zb7wicv+oNImIU/r3/B9HC1BCQsbez2b732HwNyo6N/pZc4BEsheVJhG5566d/+PIrq+7nf/j6elY6bn2Zx0GLyK3o9IkInfE0fk/1/58N/N//PzuvAD5++tM0CKS+VSaRLKp5OQrZ3S+9narrT6ZMf8nKOjO5v5o/o+IOBOVJhEXlZJy5aSH589fX34cuWX0Iqjwv8tgOFJ6dA4gEckqVJpELJKScmWi89USc6flJy4uc3IEBFyZz+PIVqCrP2v+j4hkRypNIhmUmpq+9Nxp8bmb3VzX8ve/UmL+fQsJufHj194CA3UJDBERR+njUrItu/360nMn5Sc2NnNy5Mhx81Jzu+ITFKTSIyJyv+jjVlyW3X5la83NSs3tik9sbMaP6LqWn9+ti82tik9QkOb5iIi4CpUmsZQxV4rPuXO3vt2oAMXEZE7p8fW9/W6smxWfoCBd30tEJLtQaZJMYbdfKTG3Kj5nz9748Yxewf0qH5/bl56blZ+gIPD2vrv3FxGR7EGlSdJJSbmyFedOi8/dnLwQrhSXq0do5cyZ/hYSkr7sXPtzUJAuZCoiIveHSlMWlZR0pcg4Wnqu3mJi7u59c+S4cfG53U0nMBQREWen0uTkEhLuvPicO3f35/AJCrrz4hMSol1dIiKSdak03QfGQHz8nZWeq7fLlzP+vjbblSKTkfKjw9hFRETS01djBh08CH/+6Xj5SUrK+Hu5u9958cmZU2dtFhERyUwqTRn03nswbdqdPcfT887n+4SGXrnMha7YLiIiYi2VpgwqUgTKl3e8+OTMeeUkiCo/IiIirslmTGacHtB5xcbGEhQURExMDIGBgVbHEREREQc44/e3S8x4mTRpEoUKFcLHx4fq1auzceNGqyOJiIhINuP0pem7776jZ8+eDBkyhN9//52KFSsSFRXF6dOnrY4mIiIi2YjTl6bo6GjatWvHa6+9RpkyZZgyZQp+fn5MnTrV6mgiIiKSjTh1aUpKSmLz5s1ERkamPebm5kZkZCTr1q274XMSExOJjY1NdxMRERG5W05dms6cOUNqaiphYWHpHg8LC+PkyZM3fM7IkSMJCgpKu+XPn/9+RBUREZEszqlLU0b079+fmJiYtNvRo0etjiQiIiJZgFOfpylXrly4u7tz6tSpdI+fOnWKvHnz3vA53t7eeOsCaCIiIpLJnHpLk5eXF1WqVGHp0qVpj9ntdpYuXUqNGjUsTCYiIiLZjVNvaQLo2bMnrVq1omrVqjz00EOMGzeOS5cu8dprr1kdTURERLIRpy9Nzz//PP/88w+DBw/m5MmTVKpUiYULF143OVxERETkXtJlVERERMTpOOP3t1PPaRIRERFxFipNIiIiIg5QaRIRERFxgNNPBL9bV6ds6XIqIiIiruPq97YzTb3O8qXp4sWLALqcioiIiAs6e/YsQUFBVscAssHRc3a7nePHjxMQEIDNZrM6jkNiY2PJnz8/R48edZojBjJTVh8fZP0xanyuL6uPMauPD7L+GGNiYihQoADnz58nODjY6jhANtjS5ObmRr58+ayOkSGBgYFZ8i/CVVl9fJD1x6jxub6sPsasPj7I+mN0c3Oe6dfOk0RERETEiak0iYiI
iDhApckJeXt7M2TIELy9va2Ock9k9fFB1h+jxuf6svoYs/r4IOuP0RnHl+UngouIiIhkBm1pEhEREXGASpOIiIiIA1SaRERERByg0nQPjBo1CpvNRvfu3dMeS0hIoEuXLoSGhuLv78/TTz/NqVOn0j3vyJEjNG3aFD8/P/LkycNbb71FSkpKunVWrFjBgw8+iLe3N8WKFWP69OnXvf+kSZMoVKgQPj4+VK9enY0bN971mP7++29efvllQkND8fX1pXz58mzatCltuTGGwYMHEx4ejq+vL5GRkezfvz/da5w7d46WLVsSGBhIcHAwbdu2JS4uLt0627Zt45FHHsHHx4f8+fMzevTo67LMnDmTUqVK4ePjQ/ny5Zk/f/5djy81NZVBgwZRuHBhfH19KVq0KMOHD093+n5XGuOqVato1qwZERER2Gw2Zs2alW65M43FkSx3Osbk5GT69u1L+fLlyZEjBxEREbz66qscP37cZcZ4u/+H1+rYsSM2m41x48ZlqfHt3r2b5s2bExQURI4cOahWrRpHjhxJW+7sn6u3G2NcXBxdu3YlX758+Pr6UqZMGaZMmZJuHWce48iRI6lWrRoBAQHkyZOHJ554gr179zptfkey3JaRTLVx40ZTqFAhU6FCBfPmm2+mPd6xY0eTP39+s3TpUrNp0ybz8MMPm5o1a6YtT0lJMeXKlTORkZHmjz/+MPPnzze5cuUy/fv3T1vnzz//NH5+fqZnz55m165dZsKECcbd3d0sXLgwbZ3//ve/xsvLy0ydOtXs3LnTtGvXzgQHB5tTp05leEznzp0zBQsWNK1btzYbNmwwf/75p1m0aJE5cOBA2jqjRo0yQUFBZtasWWbr1q2mefPmpnDhwuby5ctp6zRq1MhUrFjRrF+/3qxevdoUK1bMvPjii2nLY2JiTFhYmGnZsqXZsWOH+fbbb42vr6/55JNP0tZZu3atcXd3N6NHjza7du0yAwcONJ6enmb79u0ZHp8xxrz33nsmNDTUzJ071xw6dMjMnDnT+Pv7m/Hjx7vkGOfPn2/efvtt8+OPPxrA/PTTT+mWO9NYHMlyp2O8cOGCiYyMNN99953Zs2ePWbdunXnooYdMlSpV0r2GM4/xdv8Pr/rxxx9NxYoVTUREhBk7dmyWGd+BAwdMzpw5zVtvvWV+//13c+DAATN79ux0n2XO/rl6uzG2a9fOFC1a1CxfvtwcOnTIfPLJJ8bd3d3Mnj3bJcYYFRVlpk2bZnbs2GG2bNlimjRpYgoUKGDi4uKcMv/tsjhCpSkTXbx40RQvXtz88ssvpm7dumml6cKFC8bT09PMnDkzbd3du3cbwKxbt84Yc+Uvl5ubmzl58mTaOpMnTzaBgYEmMTHRGGNMnz59TNmyZdO95/PPP2+ioqLS7j/00EOmS5cuafdTU1NNRESEGTlyZIbH1bdvX1O7du2bLrfb7SZv3rxmzJgxaY9duHDBeHt7m2+//dYYY8yuXbsMYH777be0dRYsWGBsNpv5+++/jTHGfPzxxyYkJCRtvFffu2TJkmn3n3vuOdO0adN071+9enXToUOHDI/PGGOaNm1q2rRpk+6xp556yrRs2dLlx/jvD2tnGosjWTIyxhvZuHGjAczhw4ddbow3G9+xY8fMAw88YHbs2GEKFiyYrjS5+vief/558/LLL9/0Oa72uXqjMZYtW9YMGzYs3WMPPvigefvtt11yjKdPnzaAWblypdPldySLI7R7LhN16dKFpk2bEhkZme7xzZs3k5ycnO7xUqVKUaBAAdatWwfAunXrKF++PGFhYWnrREVFERsby86dO9PW+fdrR0VFpb1GUlISmzdvTreOm5sbkZGRaetkxJw5c6hatSrPPvssefLkoXLlynz22Wdpyw8dOsTJkyfTvW9QUBDVq1dPN77g4GCqVq2atk5kZCRubm5s2LAhbZ06derg5eWVbnx79+7l/PnzDv0OMqpmzZosXbqUffv2AbB161bWrFlD48aN
s8wYr3KmsTiSJbPExMRgs9nSrmHl6mO02+288sorvPXWW5QtW/a65a48Prvdzrx58yhRogRRUVHkyZOH6tWrp9u95eqfq3Dlc2fOnDn8/fffGGNYvnw5+/bto2HDhi45xpiYGABy5szpdPkdyeIIlaZM8t///pfff/+dkSNHXrfs5MmTeHl5XXfBwbCwME6ePJm2zrV/aK4uv7rsVuvExsZy+fJlzpw5Q2pq6g3XufoaGfHnn38yefJkihcvzqJFi+jUqRNvvPEGM2bMSJfvVu978uRJ8uTJk265h4cHOXPmzJTfwd2MD6Bfv3688MILlCpVCk9PTypXrkz37t1p2bJllhnjVc40FkeyZIaEhAT69u3Liy++mHaNLlcf4/vvv4+HhwdvvPHGDZe78vhOnz5NXFwco0aNolGjRixevJgnn3ySp556ipUrV6a9ryt/rgJMmDCBMmXKkC9fPry8vGjUqBGTJk2iTp06LjdGu91O9+7dqVWrFuXKlXO6/I5kcUSWv2Dv/XD06FHefPNNfvnlF3x8fKyOk+nsdjtVq1ZlxIgRAFSuXJkdO3YwZcoUWrVqZXG6zPH999/z9ddf880331C2bFm2bNlC9+7diYiIyDJjzK6Sk5N57rnnMMYwefJkq+Nkis2bNzN+/Hh+//13bDab1XEynd1uB6BFixb06NEDgEqVKvHrr78yZcoU6tata2W8TDNhwgTWr1/PnDlzKFiwIKtWraJLly5ERERct2XF2XXp0oUdO3awZs0aq6PcU9rSlAk2b97M6dOnefDBB/Hw8MDDw4OVK1fy0Ucf4eHhQVhYGElJSVy4cCHd806dOkXevHkByJs373Wz+K/ev906gYGB+Pr6kitXLtzd3W+4ztXXyIjw8HDKlCmT7rHSpUunHcVy9bVv9b558+bl9OnT6ZanpKRw7ty5TPkd3M34AN566620rU3ly5fnlVdeoUePHmlbDrPCGK9yprE4kuVuXC1Mhw8f5pdffkl3JXhXHuPq1as5ffo0BQoUSPvMOXz4ML169aJQoUIuP75cuXLh4eFx288dV/5cvXz5MgMGDCA6OppmzZpRoUIFunbtyvPPP88HH3zgUmPs2rUrc+fOZfny5eTLly/tcWfK70gWR6g0ZYL69euzfft2tmzZknarWrUqLVu2TPvZ09OTpUuXpj1n7969HDlyhBo1agBQo0YNtm/fnu5D7uqH/NUPjho1aqR7javrXH0NLy8vqlSpkm4du93O0qVL09bJiFq1al13GOm+ffsoWLAgAIULFyZv3rzp3jc2NpYNGzakG9+FCxfYvHlz2jrLli3DbrdTvXr1tHVWrVpFcnJyuvGVLFmSkJAQh34HGRUfH4+bW/q/Du7u7mn/4s0KY7zKmcbiSJaMulqY9u/fz5IlSwgNDU233JXH+Morr7Bt27Z0nzkRERG89dZbLFq0yOXH5+XlRbVq1W75uVOlShWX/lxNTk4mOTn5lp87zj5GYwxdu3blp59+YtmyZRQuXDjdcmfK70gWhzg8ZVzuyLVHzxlz5VDHAgUKmGXLlplNmzaZGjVqmBo1aqQtv3rYZcOGDc2WLVvMwoULTe7cuW942OVbb71ldu/ebSZNmnTDwy69vb3N9OnTza5du0z79u1NcHBwuiMT7tTGjRuNh4eHee+998z+/fvN119/bfz8/Mx//vOftHVGjRplgoODzezZs822bdtMixYtbngIe+XKlc2GDRvMmjVrTPHixdMd/nzhwgUTFhZmXnnlFbNjxw7z3//+1/j5+V13+LOHh4f54IMPzO7du82QIUMy5ZQDrVq1Mg888EDaKQd+/PFHkytXLtOnTx+XHOPFixfNH3/8Yf744w8DmOjoaPPHH3+kHTnmTGNxJMudjjEpKck0b97c5MuXz2zZssWcOHEi7XbtkWLOPMbb/T/8t38fPefq4/vxxx+Np6en+fTTT83+/fvTDjNfvXp12ms4++fq7cZYt25dU7ZsWbN8+XLz
559/mmnTphkfHx/z8ccfu8QYO3XqZIKCgsyKFSvS/R2Lj493yvy3y+IIlaZ75N+l6fLly6Zz584mJCTE+Pn5mSeffNKcOHEi3XP++usv07hxY+Pr62ty5cplevXqZZKTk9Ots3z5clOpUiXj5eVlihQpYqZNm3bde0+YMMEUKFDAeHl5mYceesisX7/+rsfz888/m3Llyhlvb29TqlQp8+mnn6ZbbrfbzaBBg0xYWJjx9vY29evXN3v37k23ztmzZ82LL75o/P39TWBgoHnttdfMxYsX062zdetWU7t2bePt7W0eeOABM2rUqOuyfP/996ZEiRLGy8vLlC1b1sybN++uxxcbG2vefPNNU6BAAePj42OKFCli3n777XRfsK40xuXLlxvgulurVq2cbiyOZLnTMR46dOiGywCzfPlylxjj7f4f/tuNSpOrj++LL74wxYoVMz4+PqZixYpm1qxZ6V7D2T9XbzfGEydOmNatW5uIiAjj4+NjSpYsaT788ENjt9tdYow3+zt27Ws7U35HstyO7f8PXERERERuQXOaRERERByg0iQiIiLiAJUmEREREQeoNImIiIg4QKVJRERExAEqTSIiIiIOUGkSERERcYBKk4iIiIgDVJpExGXUq1eP7t27u9xri0jW4GF1ABERZ/Djjz/i6elpdQwRcWIqTSIiQM6cOa2OICJOTrvnRMQhn376KREREdjt9nSPt2jRgjZt2gAwefJkihYtipeXFyVLluSrr75Kt+6FCxfo0KEDYWFh+Pj4UK5cOebOnQvA2bNnefHFF3nggQfw8/OjfPnyfPvtt9flSElJoWvXrgQFBZErVy4GDRqEo5fQ/PjjjylevDg+Pj6EhYXxzDPPpC27dvfcihUrsNls191at26dtv7s2bN58MEH8fHxoUiRIgwdOpSUlBSHcoiIa9KWJhFxyLPPPku3bt1Yvnw59evXB+DcuXMsXLiQ+fPn89NPP/Hmm28ybtw4IiMjmTt3Lq+99hr58uXj0UcfxW6307hxYy5evMh//vMfihYtyq5du3B3dwcgISGBKlWq0LdvXwIDA5k3bx6vvPIKRYsW5aGHHkrLMWPGDNq2bcvGjRvZtGkT7du3p0CBArRr1+6W+Tdt2sQbb7zBV199Rc2aNTl37hyrV6++4bo1a9bkxIkTafd3795NkyZNqFOnDgCrV6/m1Vdf5aOPPuKRRx7h4MGDtG/fHoAhQ4Zk/JcsIs7NiIg4qEWLFqZNmzZp9z/55BMTERFhUlNTTc2aNU27du3Srf/ss8+aJk2aGGOMWbRokXFzczN79+51+P2aNm1qevXqlXa/bt26pnTp0sZut6c91rdvX1O6dOnbvtYPP/xgAgMDTWxs7A2X161b17z55pvXPX7mzBlTpEgR07lz57TH6tevb0aMGJFuva+++sqEh4ffNoeIuC7tnhMRh7Vs2ZIffviBxMREAL7++mteeOEF3Nzc2L17N7Vq1Uq3fq1atdi9ezcAW7ZsIV++fJQoUeKGr52amsrw4cMpX748OXPmxN/fn0WLFnHkyJF06z388MPYbLa0+zVq1GD//v2kpqbeMnuDBg0oWLAgRYoU4ZVXXuHrr78mPj7+ls9JTk7m6aefpmDBgowfPz7t8a1btzJs2DD8/f3Tbu3atePEiRO3fU0RcV0qTSLisGbNmmGMYd68eRw9epTVq1fTsmVLh57r6+t7y+Vjxoxh/Pjx9O3bl+XLl7NlyxaioqJISkrKjOgEBATw+++/8+233xIeHs7gwYOpWLEiFy5cuOlzOnXqxNGjR5k5cyYeHv+bzRAXF8fQoUPZsmVL2m379u3s378fHx+fTMkrIs5Hc5pExGE+Pj489dRTfP311xw4cICSJUvy4IMPAlC6dGnWrl1Lq1at0tZfu3YtZcqUAaBChQocO3aMffv23XBr09q1a2nRogUvv/wyAHa7nX379qU9/6oNGzaku79+/XqKFy+eNjfqVjw8PIiMjCQyMpIh
Q4YQHBzMsmXLeOqpp65bNzo6mu+//55ff/2V0NDQdMsefPBB9u7dS7FixW77niKSdag0icgdadmyJY8//jg7d+5MKzgAb731Fs899xyVK1cmMjKSn3/+mR9//JElS5YAULduXerUqcPTTz9NdHQ0xYoVY8+ePdhsNho1akTx4sX5v//7P3799VdCQkKIjo7m1KlT15WmI0eO0LNnTzp06MDvv//OhAkT+PDDD2+be+7cufz555/UqVOHkJAQ5s+fj91up2TJktetu2TJEvr06cOkSZPIlSsXJ0+eBK5sLQsKCmLw4ME8/vjjFChQgGeeeQY3Nze2bt3Kjh07ePfdd+/m1ysizszqSVUi4lpSU1NNeHi4AczBgwfTLfv4449NkSJFjKenpylRooT58ssv0y0/e/asee2110xoaKjx8fEx5cqVM3Pnzk1b1qJFC+Pv72/y5MljBg4caF599VXTokWLtOfXrVvXdO7c2XTs2NEEBgaakJAQM2DAgHQTw29m9erVpm7duiYkJMT4+vqaChUqmO+++y7da1+dCD5kyBADXHdr1apV2voLFy40NWvWNL6+viYwMNA89NBD5tNPP73D36aIuBKbMQ6e4EREREQkG9NEcBEREREHqDSJSJawevXqdKcA+PdNRORuafeciGQJly9f5u+//77pch3pJiJ3S6VJRERExAHaPSciIiLiAJUmEREREQeoNImIiIg4QKVJRERExAEqTSIiIiIOUGkSERERcYBKk4iIiIgDVJpEREREHPD/ABnbMWMBx+IvAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "compute logp:\n",
      "   vocab_size    triton      torch\n",
      "0     32000.0  0.514511   4.622731\n",
      "1     64000.0  1.024138   8.693873\n",
      "2     96000.0  1.477454  13.541513\n",
      "3    128000.0  2.022543  18.250278\n",
      "4    160000.0  2.582751  22.690176\n",
      "5    192000.0  2.950109  27.426208\n"
     ]
    }
   ],
   "source": [
    "\n",
    "@triton.testing.perf_report(\n",
    "    triton.testing.Benchmark(\n",
    "        x_names=['vocab_size'],  # argument used as the x-axis of the plot\n",
    "        x_vals=[32000 + 16000 * i for i in range(0, 11, 2)],  # vocab sizes to sweep\n",
    "        line_arg='provider',  # one plotted line per provider\n",
    "        line_vals=['triton', \"torch\"],\n",
    "        line_names=[\n",
    "            \"triton\",\n",
    "            \"torch\",\n",
    "        ],  # labels for the lines\n",
    "        styles=[('blue', '-'), ('green', '-')],  # one style per line\n",
    "        ylabel=\"ms\",\n",
    "        plot_name=\"compute logp\",  # also used as the file name when saving the plot\n",
    "        args={'L': 2048, 'B': 16}\n",
    "    ))\n",
    "def benchmark(B, L, vocab_size, provider):\n",
    "    \"\"\"Time one (B, L, vocab_size) point for the given provider; returns ms.\"\"\"\n",
    "    device = \"cuda\"\n",
    "    dtype = torch.bfloat16\n",
    "    logits = torch.randn(B, L + 1, vocab_size, device=device, dtype=dtype)\n",
    "    # input_ids/mask are deliberately longer than logits along dim 1; the fused\n",
    "    # implementation is expected to slice internally.\n",
    "    input_ids = torch.randint(0, vocab_size - 1, (B, L + 100), dtype=torch.int64, device=device)\n",
    "    mask = torch.ones(B, L + 100, dtype=torch.int64, device=device)\n",
    "    stream = torch.cuda.Stream()\n",
    "    torch.cuda.set_stream(stream)\n",
    "    if provider == 'torch':\n",
    "        ms = triton.testing.do_bench(lambda: selective_log_softmax(logits, input_ids))\n",
    "    elif provider == 'triton':\n",
    "        ms = triton.testing.do_bench(lambda: fused_selective_log_softmax(logits, input_ids, mask=mask))\n",
    "    else:\n",
    "        # Previously `ms` was unbound here, raising NameError at the return.\n",
    "        raise ValueError(f\"unknown provider: {provider}\")\n",
    "    return ms\n",
    "benchmark.run(show_plots=True, print_data=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# grpo loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Adapted from the trl repository; the triton implementation below mirrors this\n",
    "# reference. The key simplification is that p(x) and p_old(x) come from the same\n",
    "# model (on-policy) when old_logp is None.\n",
    "def get_log_probs(logits, input_ids):\n",
    "    \"\"\"Per-token log-probs of `input_ids` under `logits`, row by row to bound memory.\"\"\"\n",
    "    per_token_logps = []\n",
    "    for logits_row, input_ids_row in zip(logits, input_ids[:, -logits.size(1):]):\n",
    "        log_probs = logits_row.log_softmax(dim=-1)\n",
    "        token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)\n",
    "        per_token_logps.append(token_log_prob)\n",
    "    return torch.stack(per_token_logps)\n",
    "\n",
    "def torch_grpo_loss(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, temperature, beta, eps_low, eps_high):\n",
    "    \"\"\"Reference GRPO loss in plain PyTorch.\n",
    "\n",
    "    Returns (per_token_loss, per_token_kl, is_clipped); per_token_kl is None when beta == 0.\n",
    "    \"\"\"\n",
    "    assert logits.is_contiguous() and completion_ids.is_contiguous()\n",
    "    assert old_logp is None or old_logp.is_contiguous()\n",
    "    if beta != 0.0:\n",
    "        assert ref_logp is not None and ref_logp.is_contiguous()\n",
    "    logits = logits[:, :-1]  # shift by one: position t scores the token at t+1\n",
    "    per_token_logps = get_log_probs(logits / temperature, completion_ids)  # autograd keeps log_probs as saved activations\n",
    "    ref_per_token_logps = ref_logp\n",
    "\n",
    "    if old_logp is None:\n",
    "        old_logp = per_token_logps.detach()  # on-policy: importance ratio is exactly 1\n",
    "    coef_1 = torch.exp(per_token_logps - old_logp)\n",
    "    coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)\n",
    "    per_token_loss1 = coef_1 * advantages.unsqueeze(1)\n",
    "    per_token_loss2 = coef_2 * advantages.unsqueeze(1)\n",
    "    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)\n",
    "    per_token_loss = per_token_loss * completion_mask if completion_mask is not None else per_token_loss\n",
    "\n",
    "    per_token_kl = None\n",
    "    if beta != 0.0:\n",
    "        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1\n",
    "        if completion_mask is not None:\n",
    "            per_token_kl *= completion_mask\n",
    "        per_token_loss = per_token_loss + beta * per_token_kl\n",
    "    is_clipped = (per_token_loss1 < per_token_loss2).float()\n",
    "    # Mask is_clipped too, so the output matches the triton kernel, which skips\n",
    "    # masked tokens and leaves zeros there.\n",
    "    if completion_mask is not None:\n",
    "        is_clipped = is_clipped * completion_mask\n",
    "    return per_token_loss, per_token_kl, is_clipped"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## triton code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @triton.autotune([triton.Config({\"BLOCK_N\":BLOCK_N}, num_stages=ns, num_warps=nw)\n",
    "#                   for BLOCK_N in [2048, 4096, 8192]\n",
    "#                   for ns in [1, 2, 4]\n",
    "#                   for nw in [1, 2, 4, 8, 16]],\n",
    "#                   key=['N'])\n",
    "@triton.jit\n",
    "def _grpo_loss_fwd_kernel(LOGITS,\n",
    "                          OLD_LOGP,\n",
    "                          REF_LOGP,\n",
    "                          INPUT_IDS,\n",
    "                          COMPLETION_MASK,\n",
    "                          ADVANTAGES,\n",
    "                          LOSS,\n",
    "                          LSE,\n",
    "                          KL,\n",
    "                          IS_CLIPPED,\n",
    "                          TEMPERATURE,\n",
    "                          BETA: tl.constexpr,\n",
    "                          EPS_LOW,\n",
    "                          EPS_HIGH,\n",
    "                          L: tl.constexpr,\n",
    "                          N: tl.constexpr,\n",
    "                          BLOCK_N: tl.constexpr = 4096):\n",
    "    # One program per (batch, position): computes the per-token GRPO loss and\n",
    "    # stores the log-sum-exp so the backward pass can recompute logp cheaply.\n",
    "    off_b = tl.program_id(0).cast(tl.int64)\n",
    "    off_l = tl.program_id(1).cast(tl.int64)\n",
    "\n",
    "    # Masked-out tokens are skipped entirely; their output slots keep the\n",
    "    # zeros they were initialised with on the host.\n",
    "    if COMPLETION_MASK is not None:\n",
    "        COMPLETION_MASK += off_b * L + off_l\n",
    "        if tl.load(COMPLETION_MASK) == 0:\n",
    "            return\n",
    "\n",
    "    # Advance every pointer to this (batch, position) element.\n",
    "    LOGITS += off_b * (L + 1) * N + off_l * N\n",
    "    INPUT_IDS += off_b * L + off_l\n",
    "    ADVANTAGES += off_b\n",
    "    LOSS += off_b * L + off_l\n",
    "    LSE += off_b * L + off_l\n",
    "    IS_CLIPPED += off_b * L + off_l\n",
    "\n",
    "    # Streaming (online) log-sum-exp over the vocabulary dimension.\n",
    "    row_max = float('-inf')\n",
    "    sum_exp = 0.\n",
    "    for blk_start in range(0, N, BLOCK_N):\n",
    "        offs = blk_start + tl.arange(0, BLOCK_N)\n",
    "        block = tl.load(LOGITS + offs, mask=offs < N, other=float('-inf')).to(tl.float32) / TEMPERATURE\n",
    "        new_max = tl.maximum(row_max, tl.max(block))\n",
    "        sum_exp = sum_exp * tl.exp(row_max - new_max) + tl.sum(tl.exp(block - new_max))\n",
    "        row_max = new_max\n",
    "    lse = row_max + tl.log(sum_exp)\n",
    "\n",
    "    # log p(token) = logit / T - logsumexp\n",
    "    idx = tl.load(INPUT_IDS)\n",
    "    logp = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE - lse\n",
    "    if OLD_LOGP is None:\n",
    "        old_logp = logp  # on-policy: the importance ratio below becomes exactly 1\n",
    "    else:\n",
    "        OLD_LOGP += off_b * L + off_l\n",
    "        old_logp = tl.load(OLD_LOGP).to(tl.float32)\n",
    "\n",
    "    # PPO-style clipped surrogate objective.\n",
    "    coef_1 = tl.exp(logp - old_logp)\n",
    "    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)\n",
    "    advantage = tl.load(ADVANTAGES).to(tl.float32)\n",
    "    per_token_loss1 = coef_1 * advantage\n",
    "    per_token_loss2 = coef_2 * advantage\n",
    "    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)\n",
    "    is_clipped = per_token_loss1 < per_token_loss2\n",
    "\n",
    "    if BETA != 0.0:\n",
    "        # KL estimator: exp(d) - d - 1 with d = ref_logp - logp (always >= 0).\n",
    "        REF_LOGP += off_b * L + off_l\n",
    "        KL += off_b * L + off_l\n",
    "        ref_logp = tl.load(REF_LOGP).to(tl.float32)\n",
    "        kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1\n",
    "        per_token_loss += BETA * kl\n",
    "        tl.store(KL, kl)\n",
    "\n",
    "    tl.store(LOSS, per_token_loss)\n",
    "    tl.store(LSE, lse)\n",
    "    tl.store(IS_CLIPPED, is_clipped)\n",
    "\n",
    "# @triton.autotune([triton.Config({\"BLOCK_N\":BLOCK_N}, num_stages=ns, num_warps=nw)\n",
    "#                   for BLOCK_N in [2048, 4096, 8192]\n",
    "#                   for ns in [1, 2, 4]\n",
    "#                   for nw in [1, 2, 4, 8, 16]],\n",
    "#                   key=['N'])\n",
    "@triton.jit\n",
    "def _grpo_loss_bwd_kernel(DLOSS,\n",
    "                          DLOGITS,\n",
    "                          LOGITS,\n",
    "                          OLD_LOGP,\n",
    "                          REF_LOGP,\n",
    "                          INPUT_IDS,\n",
    "                          ADVANTAGES,\n",
    "                          COMPLETION_MASK,\n",
    "                          LSE,\n",
    "                          TEMPERATURE,\n",
    "                          BETA: tl.constexpr,\n",
    "                          EPS_LOW,\n",
    "                          EPS_HIGH,\n",
    "                          loss_stride0,\n",
    "                          loss_stride1,\n",
    "                          L: tl.constexpr,\n",
    "                          N: tl.constexpr,\n",
    "                          BLOCK_N: tl.constexpr = 4096):\n",
    "    # One program per (batch, position): writes the full vocab row of DLOGITS.\n",
    "    off_b = tl.program_id(0).cast(tl.int64)\n",
    "    off_l = tl.program_id(1).cast(tl.int64)\n",
    "\n",
    "    DLOGITS += off_b * (L + 1) * N + off_l * N\n",
    "    if COMPLETION_MASK is not None:\n",
    "        COMPLETION_MASK += off_b * L + off_l\n",
    "        not_skip = tl.load(COMPLETION_MASK)\n",
    "        if not_skip == 0:\n",
    "            # Masked token: the gradient row must still be zero-filled because\n",
    "            # DLOGITS comes from torch.empty_like (uninitialised memory).\n",
    "            for start in range(0, N, BLOCK_N):\n",
    "                cols = tl.arange(0, BLOCK_N) + start\n",
    "                tl.store(DLOGITS + cols, 0., mask=cols < N)\n",
    "            return\n",
    "\n",
    "    LOGITS += off_b * (L + 1) * N + off_l * N\n",
    "    DLOSS += off_b * loss_stride0 + off_l * loss_stride1\n",
    "    INPUT_IDS += off_b * L + off_l\n",
    "    ADVANTAGES += off_b\n",
    "    LSE += off_b * L + off_l\n",
    "\n",
    "    dloss = tl.load(DLOSS).to(tl.float32)\n",
    "    lse = tl.load(LSE).to(tl.float32)\n",
    "\n",
    "    # Recompute logp from the saved logits and log-sum-exp (cheaper than\n",
    "    # materialising it in the forward pass).\n",
    "    idx = tl.load(INPUT_IDS)\n",
    "    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE\n",
    "    logp = x - lse\n",
    "    if OLD_LOGP is None:\n",
    "        old_logp = logp\n",
    "    else:\n",
    "        OLD_LOGP += off_b * L + off_l\n",
    "        old_logp = tl.load(OLD_LOGP).to(tl.float32)\n",
    "    coef_1 = tl.exp(logp - old_logp)\n",
    "    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)\n",
    "    advantage = tl.load(ADVANTAGES).to(tl.float32)\n",
    "    per_token_loss1 = coef_1 * advantage\n",
    "    per_token_loss2 = coef_2 * advantage\n",
    "    # d(loss)/d(logp) is zero wherever the clipped branch was the minimum.\n",
    "    mask = per_token_loss2 >= per_token_loss1\n",
    "\n",
    "    dlogp = -per_token_loss1 * mask\n",
    "    if BETA != 0.0:\n",
    "        REF_LOGP += off_b * L + off_l\n",
    "        ref_logp = tl.load(REF_LOGP).to(tl.float32)\n",
    "        dlogp += BETA * (1 - tl.exp(ref_logp - logp))\n",
    "\n",
    "    dlogp = dlogp * dloss / TEMPERATURE\n",
    "    # Backprop through log-softmax: dlogits = (one_hot(idx) - softmax) * dlogp.\n",
    "    for start_n in tl.range(0, N, BLOCK_N):\n",
    "        cols = start_n + tl.arange(0, BLOCK_N)\n",
    "        logits = tl.load(LOGITS + cols, mask=cols < N, other=0.).to(tl.float32) / TEMPERATURE\n",
    "        probs = tl.exp(logits - lse)\n",
    "        dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp\n",
    "        # Cast to the output tensor's element type instead of hard-coding\n",
    "        # bfloat16, so fp16/fp32 logits also get correctly-typed gradients.\n",
    "        tl.store(DLOGITS + cols, dlogits.to(DLOGITS.dtype.element_ty), mask=cols < N)\n",
    "\n",
    "class GrpoLoss(torch.autograd.Function):\n",
    "    \"\"\"Autograd wrapper around the fused GRPO forward/backward triton kernels.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, temperature, beta, eps_low, eps_high, inplace):\n",
    "        # The kernels use raw stride-1 pointer arithmetic, so every tensor they\n",
    "        # read must be contiguous — including advantages, which was previously\n",
    "        # unchecked and would silently give wrong results for a view.\n",
    "        assert logits.is_contiguous() and completion_ids.is_contiguous()\n",
    "        assert advantages.is_contiguous()\n",
    "        assert old_logp is None or old_logp.is_contiguous()\n",
    "        if beta != 0.0:\n",
    "            assert ref_logp is not None and ref_logp.is_contiguous()\n",
    "\n",
    "        B, L_ADD_1, N = logits.shape\n",
    "        L = L_ADD_1 - 1\n",
    "\n",
    "        if completion_mask is not None:\n",
    "            assert completion_mask.is_contiguous()\n",
    "\n",
    "        # zeros (not empty): the kernel skips masked positions, which must then\n",
    "        # read back as 0.\n",
    "        loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)\n",
    "        lse = torch.zeros_like(loss)\n",
    "        is_clipped = torch.zeros_like(loss)\n",
    "        kl = torch.zeros_like(loss) if beta != 0.0 else None\n",
    "        kwargs = {\"BLOCK_N\": 2048, \"num_stages\": 2, \"num_warps\": 1}\n",
    "        _grpo_loss_fwd_kernel[(B, L)](logits,\n",
    "                                      old_logp,\n",
    "                                      ref_logp,\n",
    "                                      completion_ids,\n",
    "                                      completion_mask,\n",
    "                                      advantages,\n",
    "                                      loss,\n",
    "                                      lse,\n",
    "                                      kl,\n",
    "                                      is_clipped,\n",
    "                                      temperature,\n",
    "                                      beta,\n",
    "                                      eps_low,\n",
    "                                      eps_high,\n",
    "                                      L,\n",
    "                                      N,\n",
    "                                      **kwargs)\n",
    "        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)\n",
    "        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)\n",
    "        return loss, kl, is_clipped\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, *args):\n",
    "        # Only the gradient w.r.t. `loss` is propagated; kl / is_clipped are\n",
    "        # treated as non-differentiable outputs.\n",
    "        dloss = args[0]\n",
    "        logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors\n",
    "        temperature, beta, eps_low, eps_high, inplace = ctx.infos\n",
    "        B, L_ADD_1, N = logits.shape\n",
    "        L = L_ADD_1 - 1\n",
    "        dlogits = torch.empty_like(logits)\n",
    "        kwargs = {\"BLOCK_N\": 4096, \"num_stages\": 1, \"num_warps\": 16}\n",
    "        _grpo_loss_bwd_kernel[(B, L)](dloss,\n",
    "                                      dlogits,\n",
    "                                      logits,\n",
    "                                      old_logp,\n",
    "                                      ref_logp,\n",
    "                                      completion_ids,\n",
    "                                      advantages,\n",
    "                                      completion_mask,\n",
    "                                      lse,\n",
    "                                      temperature,\n",
    "                                      beta,\n",
    "                                      eps_low,\n",
    "                                      eps_high,\n",
    "                                      *dloss.stride(),\n",
    "                                      L,\n",
    "                                      N,\n",
    "                                      **kwargs)\n",
    "        # The kernel never writes the final (L+1)-th position; zero it here\n",
    "        # because dlogits came from torch.empty_like.\n",
    "        dlogits[:, -1, :] = 0\n",
    "        return dlogits, None, None, None, None, None, None, None, None, None, None\n",
    "\n",
    "def triton_grpo_loss(logits, old_logp, ref_logp, completion_ids, advantages,\n",
    "                     completion_mask=None, temperature=0.9, beta=0.04,\n",
    "                     eps_low=0.2, eps_high=0.4, inplace=True):\n",
    "    \"\"\"Fused GRPO loss.\n",
    "\n",
    "    Returns (per_token_loss, per_token_kl, is_clipped). `old_logp` may be None\n",
    "    (on-policy) and `ref_logp` may be None when beta == 0 (no KL term).\n",
    "    \"\"\"\n",
    "    assert logits is not None and completion_ids is not None and advantages is not None, \"must provide logits、completion_ids and advantages\"\n",
    "    return GrpoLoss.apply(logits, old_logp, ref_logp, completion_ids, advantages,\n",
    "                          completion_mask, temperature, beta, eps_low, eps_high,\n",
    "                          inplace)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 精度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed):\n",
    "    \"\"\"Seed torch RNGs (CPU and all CUDA devices) for reproducibility.\"\"\"\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)  # all visible GPUs, not just the current one\n",
    "set_seed(40)\n",
    "\n",
    "vocab_size = 12800\n",
    "B, L = 8, 1024\n",
    "device = \"cuda\"\n",
    "dtype = torch.bfloat16\n",
    "logits1 = torch.randn(B, L + 1, vocab_size, device=device, dtype=dtype)\n",
    "logits1.requires_grad_(True)\n",
    "logits2 = deepcopy(logits1)\n",
    "gold_logits = logits1.detach().clone().float()  # fp32 reference for accuracy checks\n",
    "gold_logits.requires_grad_(True)\n",
    "\n",
    "completion_ids = torch.randint(0, vocab_size-1, (B, L), dtype=torch.int64, device=device)\n",
    "# Toggle the commented lines below to exercise the masked / on-policy / no-KL variants.\n",
    "completion_mask = torch.ones_like(completion_ids, dtype=torch.int32)\n",
    "# completion_mask[:, -200:] = 0\n",
    "completion_mask = None\n",
    "ref_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "# ref_logp = None\n",
    "old_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "# old_logp = None\n",
    "advantages = torch.randn(B, device=device, dtype=torch.float32)\n",
    "temperature, beta, eps_low, eps_high = 0.9, 0.2, 0.2, 0.4\n",
    "inplace = True\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch-bf16 vs torch-fp32\n",
      "最大差异: 0.03676421567797661, 平均差异: 0.015371764078736305\n",
      "最大差异: 0.036764197051525116, 平均差异: 0.015372702851891518\n",
      "最大差异: 0.07438357919454575, 平均差异: 0.02046402171254158\n",
      "triton-bf16 vs torch-fp32\n",
      "最大差异: 2.7585529096540995e-06, 平均差异: 4.938870006299112e-07\n",
      "最大差异: 2.741835942288162e-06, 平均差异: 4.928543830828858e-07\n",
      "最大差异: 0.0038913956377655268, 平均差异: 0.001404356211423874\n"
     ]
    }
   ],
   "source": [
    "# Shared argument tuple: all three calls differ only in the logits tensor.\n",
    "loss_args = (old_logp, ref_logp, completion_ids, advantages,\n",
    "             completion_mask, temperature, beta, eps_low, eps_high)\n",
    "\n",
    "loss1, kl1, is_clipped1 = torch_grpo_loss(logits1, *loss_args)\n",
    "loss2, kl2, is_clipped2 = triton_grpo_loss(logits2, *loss_args)\n",
    "gold_loss, gold_kl, gold_is_clipped = torch_grpo_loss(gold_logits, *loss_args)\n",
    "\n",
    "# Backprop the same upstream gradient through all three implementations.\n",
    "dy = torch.randn_like(loss1)\n",
    "loss1.backward(dy)\n",
    "loss2.backward(dy)\n",
    "gold_loss.backward(dy)\n",
    "\n",
    "print(\"torch-bf16 vs torch-fp32\")\n",
    "compare(loss1, gold_loss)\n",
    "compare(kl1, gold_kl)\n",
    "compare(logits1.grad, gold_logits.grad)\n",
    "print(\"triton-bf16 vs torch-fp32\")\n",
    "compare(loss2, gold_loss)\n",
    "compare(kl2, gold_kl)\n",
    "compare(logits2.grad, gold_logits.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 259,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-17.4833, device='cuda:0', grad_fn=<MinBackward1>)"
      ]
     },
     "execution_count": 259,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.min(ref_logp.unsqueeze(-1) * logits1[:, :-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 209,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 209,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "logits1.allclose(logits2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(inf, device='cuda:0', dtype=torch.bfloat16)"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.max(logits2.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.110487222671509\n",
      "0.34517961740493774\n"
     ]
    }
   ],
   "source": [
    "# Forward + backward wall time (ms) for each implementation.\n",
    "torch_fn = lambda: torch_grpo_loss(logits1, old_logp, ref_logp, completion_ids,\n",
    "                                   advantages, completion_mask, temperature,\n",
    "                                   beta, eps_low, eps_high)[0].sum().backward()\n",
    "triton_fn = lambda: triton_grpo_loss(logits2, old_logp, ref_logp, completion_ids,\n",
    "                                     advantages, completion_mask, temperature,\n",
    "                                     beta, eps_low, eps_high)[0].sum().backward()\n",
    "\n",
    "print(triton.testing.do_bench(torch_fn))\n",
    "print(triton.testing.do_bench(triton_fn))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 速度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAk0AAAGxCAYAAAB/QoKnAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAUnJJREFUeJzt3Xd4FOXi9vHvpieEJBBIQXqRDtINIKAEgxQb1h9yUDG0IL2qVKWKSK8HFRXrOTaK9ICAFAGR3gQFgRBaElrqPu8fvOwhirKEwGyS+3Nde8HuzO7eT4Ddm5lnZmzGGIOIiIiI/CM3qwOIiIiI5AQqTSIiIiJOUGkSERERcYJKk4iIiIgTVJpEREREnKDSJCIiIuIElSYRERERJ6g0iYiIiDjBw+oAd5rdbufEiRPkz58fm81mdRwRERFxgjGGCxcuUKRIEdzcXGMbT64vTSdOnKBYsWJWxxAREZEsOHbsGEWLFrU6BpAHSlP+/PmBqz/0gIAAi9OIiIiIM5KSkihWrJjje9wV5PrSdG2XXEBAgEqTiIhIDuNKU2tcYyehiIiIiItTaRIRERFxgkqTiIiIiBNy/ZwmZ2VkZJCWlmZ1jFzD09MTd3d3q2OIiIhkmzxfmowxxMXFkZCQYHWUXCcoKIiwsDCXmsQnIiKSVXm+NF0rTCEhIfj5+ekLPhsYY7h8+TLx8fEAhIeHW5xIRETk9uXp0pSRkeEoTMHBwVbHyVV8fX0BiI+PJyQkRLvqREQkx8vTE8GvzWHy8/OzOEnudO3nqrliIiKSG+Tp0nSNdsndGfq5iohIbqLSJCIiIuIElaZcbtiwYdx3331WxxAREcnxVJpyqCZNmtCzZ8+brte3b19WrlzpuP/iiy/y+OOP37lgIiIiuZRKUy5ljCE9PR1/f38dGSgiInfEpdRLTN40GbuxWx3lrlBpyoFefPFF1qxZw6RJk7DZbNhsNj744ANsNhvff/89tWrVwtvbm3Xr1mXaPTds2DDmzZvHt99+63je6tWrAdi5cycPPfQQvr6+BAcH07FjRy5evJjpPR9//HHGjx9PeHg4wcHBxMTE6Mg4EZE8atGBRVSeXpkeS3owe+tsq+PcFXn6PE1/ZgxcvmzNe/v5gbMHm02aNIkDBw5QpUoVRowYAcDu3bsBGDhwIOPHj6d06dIUKFDAUYrg6q66vXv3kpSUxPvvvw9AwYIFuXTpElFRUURERPDTTz8RHx/PK6+8Qrdu3fjggw8cz4+NjSU8PJzY2FgOHTrEs88+y3333Ud0dHS2/AxERMT1nbhwgh5LevCfPf8BoHhgcUoElrA41d2h0nSdy5fB39+a9754EfLlc27dwMBAvLy88PPzIywsDIB9+/YBMGLECJo1a3bD5/n7++Pr60tKSorjeQDz5s0jOTmZDz/8kHz/P8TUqVNp3bo1Y8eOJTQ0FIACBQowdepU3N3dqVChAi1btmTlypUqTSIieUCGPYPpP03n9VWvcyH1Au42d3rd34thTYaRz8vJL7AcTqUpl6ldu/YtP2fv3r1Ur17dUZgAGjRogN1uZ//+/Y7SVLly5Uxn9g4PD2fnzp23H1pERFzatpPb6LSwE1tObAGg3j31mNVqFtXDqluc7O5SabqOn9/VLT5WvXd2yOfs5qos8PT0zHTfZrNht+eNyX8iInnRhZQLDIkdwuTNVyd7B3oHMrrpaDrW6oi7W967PJZK03VsNud3kVnNy8uLjIyMbHlexYoV+eCDD7h06ZKjdK1fvx43NzfKly+fLXlFRCRn+WbfN7z6/av8kfQHAM9WfpZ3o94lPH/evQi7jp7LoUqWLMmmTZv47bffOHPmjNNbfEqWLMmOHTvYv38/Z86cIS0tjbZt2+Lj40P79u3ZtWsXsbGxvPrqq7Rr186xa05ERPKGo4lHeeyzx3ji8yf4I+kPSgWVYknbJXz21Gd5ujCBSlOO1bdvX9zd3alUqRKFCxfm6NGj
Tj0vOjqa8uXLU7t2bQoXLsz69evx8/Nj6dKlnDt3jjp16vDUU0/RtGlTpk6deodHISIiriLdns6EDROoNK0S3+3/Dg83DwY1HMSurruIKhtldTyXYDPGGKtD3ElJSUkEBgaSmJhIQEBApmXJyckcOXKEUqVK4ePjY1HC3Es/XxGRnGHz8c10WtiJ7XHbAWhYvCEzW86kckhlyzL90/e3VTSnSUREJI9KTE7k9VWvM/2n6RgMBXwK8Hazt3mpxku42bQz6s9UmkRERPIYYwz/2fMfeizpwcmLJwFoV60d4x8eT0i+EIvTuS6VJhERkTzkyPkjxCyO4ftD3wNQrmA5ZrScQdPSTS1O5vpUmkRERPKAtIw0JmyYwPA1w7mSfgUvdy8GNhjIoAcG4eOheafOUGkSERHJ5X489iOdFnZiV/wuAJqUbMLMljMpX0jn4rsVKk0iIiK51Pkr5xm4YiCzt80GINg3mAlRE2hXrR02Z68SLw4qTSIiIrmMMYZPd31Kr6W9iL8UD8DL973MuGbjCPYLtjhdzqXSJCIikoscOneILou6sOLwCgAqFqrIzFYzaVSikcXJcj6VJvlHL774IgkJCXzzzTdWRxERkX+Qkp7CuPXjGLl2JCkZKXi7ezO40WD6NeiHl7uX1fFyBZWmHKpJkybcd999TJw40eooIiJisTW/raHzos7sO7MPgGalmzG95XTKFixrcbLcRaUpD0tNTcXLS//7EBHJqc5cPkP/5f15f/v7AITkC2Fi1ESeq/KcJnrfATpHeg704osvsmbNGiZNmoTNZsNms/Hbb7+xZs0a6tati7e3N+Hh4QwcOJD09HTH85o0aUK3bt3o2bMnhQoVIirq6gUYd+/eTatWrQgICCB//vw88MAD/Prrr5nec/z48YSHhxMcHExMTAxpaWl3dcwiIvI/xhg+2P4BFaZWcBSmTrU6sS9mH89XfV6F6Q7RlqYcaNKkSRw4cIAqVaowYsQIADIyMmjRogUvvvgiH374Ifv27SM6OhofHx+GDRvmeO68efPo0qUL69evB+D48eM0atSIJk2asGrVKgICAli/fn2mshUbG0t4eDixsbEcOnSIZ599lvvuu4/o6Oi7Om4REYF9Z/bReWFn1vy+BoAqIVWY1WoW9YvVtzhZ7qfSdB1jDJfTLlvy3n6efk7/zyAwMBAvLy/8/PwICwsD4PXXX6dYsWJMnToVm81GhQoVOHHiBAMGDGDIkCG4uV3dqFiuXDnGjRvneK3XXnuNwMBAPvvsMzw9PQG49957M71fgQIFmDp1Ku7u7lSoUIGWLVuycuVKlSYRkbsoOT2ZUWtHMWbdGNLsafh6+DKsyTB63d8LT3dPq+PlCSpN17mcdhn/0f6WvPfFQRfJ55Uvy8/fu3cvERERmYpXgwYNuHjxIn/88QfFixcHoFatWpmet337dh544AFHYbqRypUr4+7u7rgfHh7Ozp07s5xVRERuzYrDK+iyqAuHzh0CoEW5FkxrMY2SQSWtDZbHqDTlMfnyZS5mvr6+N33OnwuVzWbDbrdnay4REfmr+Evx9F7am/k75wMQ7h/O5Ecm06ZiG81bsoBK03X8PP24OOiiZe99K7y8vMjIyHDcr1ixIv/9738xxjj+Ia1fv578+fNTtGjRv32datWqMW/ePNLS0v5xa5OIiNw9dmNn7ra59F/Rn4TkBGzYiKkTw1sPvUWgT6DV8fIslabr2Gy229pFdjeVLFmSTZs28dtvv+Hv70/Xrl2ZOHEir776Kt26dWP//v0MHTqU3r17O+Yz3Ui3bt2YMmUKzz33HIMGDSIwMJCNGzdSt25dypfXhRxFRO62XfG76LywM+uPXT1gp0ZYDWa1mkWde+pYnEx0yoEcqm/fvri7u1OpUiUKFy5MWloaixcvZvPmzVSvXp3OnTvToUMH3njjjX98neDgYFatWsXFixdp3LgxtWrVYs6cOdrqJCJyl11Ou8ygFYOoMasG64+tJ59nPiY8
PIHN0ZtVmFyEzRhjrA5xJyUlJREYGEhiYiIBAQGZliUnJ3PkyBFKlSqFj4+PRQlzL/18RUSc8/3B74lZHMORhCMAPF7hcSY3n0yxwGIWJ7POP31/W0W750RERCxy8sJJei7tyRe7vwCgWEAxpjwyhccqPGZxMrkRlSYREZG7LMOewcwtM3lt1WskpSThZnOjZ72eDH9wOP5e1pz6Rm5OpUlEROQu2h63nU4LO7H5+GYA6hSpw6xWs6gRXsPiZHIzKk0iIiJ3wcXUiwxbPYyJGyeSYTLI75WfUU1H0aV2F9zd3G/+AmI5lSYREZE7bMH+BcQsjuFY0jEAnq70NBObT6RI/iIWJ5NbodLE1WvOSfbTz1VE8ro/kv6g+/fd+Xrf1wCUDCrJtBbTaFGuhcXJJCvydGm6di6iy5cvO3U5Ebk1ly9fvfixzvkkInlNuj2dqZunMjh2MBdTL+Lh5kGfiD4MaTzklq8AIa4jT5cmd3d3goKCiI+PB8DPz0/X8skGxhguX75MfHw8QUFBmS72KyKS2205sYVOCzux7eQ2ACKKRjCr1Syqhla1OJncrjxdmgDCwsIAHMVJsk9QUJDj5ysiktslpSTxxqo3mPbTNOzGTpBPEGMjx/JKzVdws+kCHLlBni9NNpuN8PBwQkJCSEtLszpOruHp6aktTCKSJxhj+GrvV3Rf0p0TF04A8H9V/48JD08g1D/U4nSSnfJ8abrG3d1dX/IiInJLfkv4jW6Lu7Ho4CIAyhQow4yWM2hWppnFyeROUGkSERG5RWkZaUzcOJFha4ZxOe0ynm6eDGgwgNceeA1fTx1YlFupNImIiNyCjX9spNPCTuw4tQOARiUaMbPlTCoWrmhxMrnTVJpERESckJCcwKAVg5i1dRYGQ0HfgoxvNp4X73tRR17nESpNIiIi/8AYw+e7P6fnkp6cunQKgPbV2/N2s7cpnK+wxenkblJpEhER+Ru/nvuVrou7suzXZQCUDy7PzFYzaVKyibXBxBIqTSIiIn+SmpHK+B/H8+YPb5Kcnoy3uzevPfAaAxoMwNvD2+p4YhGVJhERkeus/X0tnRd1Zs/pPQA8VOohZrScwb3B91qcTKym0iQiIgKcu3KO/sv7M/fnuQAU9ivMhKgJtK3aVhO9BVBpEhGRPM4Yw0c7PqLPsj6cuXwGgOia0YyJHENB34IWpxNXotIkIiJ51oGzB+iyqAurjqwCoHLhysxsNZOGxRtanExckUqTiIjkOSnpKYxZN4ZR60aRmpGKj4cPQxoNoU/9Pni5e1kdT1yUpZddzsjIYPDgwZQqVQpfX1/KlCnDm2++iTHGsY4xhiFDhhAeHo6vry+RkZEcPHjQwtQiIpKTxR6JpdrMagxbM4zUjFSal23O7q67GfTAIBUm+UeWbmkaO3YsM2bMYN68eVSuXJktW7bw0ksvERgYSPfu3QEYN24ckydPZt68eZQqVYrBgwcTFRXFnj178PHxsTK+iIjkIKcvnabv8r58+MuHAIT5hzGp+SServS0JnqLU2zm+s06d1mrVq0IDQ1l7ty5jsfatGmDr68vH3/8McYYihQpQp8+fejbty8AiYmJhIaG8sEHH/Dcc8/d9D2SkpIIDAwkMTGRgICAOzYWERFxTXZj5/2f36f/iv6cu3IOGza61O7CyKYjCfIJsjqe/A1X/P62dPdc/fr1WblyJQcOHADgl19+Yd26dTzyyCMAHDlyhLi4OCIjIx3PCQwMpF69emzYsMGSzCIiknPsOb2HJh804ZUFr3Duyjmqh1ZnQ4cNTGs5TYVJbpmlu+cGDhxIUlISFSpUwN3dnYyMDEaOHEnbtm0BiIuLAyA0NDTT80JDQx3L/iwlJYWUlBTH/aSkpDuUXkREXNWVtCu89cNbvP3j26TZ0/Dz9GNEkxH0uL8HHm46BkqyxtK/OV988QXz58/nk08+oXLlymzfvp2ePXtSpEgR2rdvn6XXHD16NMOHD8/mpCIiklMs/3U5XRZ1
4dfzvwLQ+t7WTHlkCiWCSlicTHI6S3fP9evXj4EDB/Lcc89RtWpV2rVrR69evRg9ejQAYWFhAJw6dSrT806dOuVY9meDBg0iMTHRcTt27NidHYSIiLiE+EvxtP2qLQ9//DC/nv+Ve/Lfw1fPfMW3z32rwiTZwtLSdPnyZdzcMkdwd3fHbrcDUKpUKcLCwli5cqVjeVJSEps2bSIiIuKGr+nt7U1AQECmm4iI5F52Y+ff2/5NhakV+GTnJ9iw0b1ud/bE7OGJik/oyDjJNpbunmvdujUjR46kePHiVK5cmZ9//pkJEybw8ssvA2Cz2ejZsydvvfUW5cqVc5xyoEiRIjz++ONWRhcRERew5/QeOi3sxLqj6wCoEVaDWa1mUeeeOhYnk9zI0tI0ZcoUBg8eTNeuXYmPj6dIkSJ06tSJIUOGONbp378/ly5domPHjiQkJNCwYUOWLFmiczSJiORhV9KuMHLtSMatH0eaPY18nvl488E3ebXeq5roLXeMpedpuhtc8TwPIiKSdSsOr6Dzws6ZJnpPbTGV4oHFLU4m2ckVv79Vx0VEJEeIvxRP76W9mb9zPgBF8hdhyiNTeKKC5i3J3aHSJCIiLs1u7Lz383v0X96f88nnsWGjW91uvPXQWwR4u8YWCMkbVJpERMRl7Tm9h84LO7P26FoA7gu7j9mtZmuit1hCpUlERFzOlbQrjFo7irHrxzrO6P3mg2/SvV53TfQWy+hvnoiIuJQVh1fQZVEXDp07BECre1sx9ZGpOkGlWE6lSUREXMLpS6fpvaw3H+/4GNBEb3E9Kk0iImIpYwzvb3+ffsv7ce7KOWzYiKkTw8imIzXRW1yKSpOIiFhm7+m9dF7UmR9+/wGA6qHVmd16NnXvqWtxMpG/UmkSEZG7Ljk9mVFrRzFm3RjHRO8RTUbQ4/4emugtLkt/M0VE5K5aeXglXRZ14eC5gwC0LNeSaS2maaK3uDyVJhERuStOXzpNn2V9+GjHRwCE+4cz5ZEpPFnxSU30lhxBpUlERO6oG0307lqnKyMfGkmgT6DV8UScptIkIiJ3zL4z++i0sJNjone10GrMbjWbekXrWZxM5NapNImISLZLTk9m9NrRjF432jHRe3iT4fSo1wNPd0+r44lkiUqTiIhkq1VHVtF5YWfHRO8W5VowrcU0SgaVtDaYyG1SaRIRkWxx+tJp+i7vy4e/fAhcneg9+ZHJtKnYRhO9JVdQaRIRkdtijGHeL/Pou6wvZ6+cxYaNLrW7MKrpKE30llxFpUlERLJs35l9dF7YmTW/rwGuTvSe1WoW9xe93+JkItlPpUlERG5ZcnoyY9aNYfS60aRmpOLr4cvwJsPpeX9PTfSWXEulSUREbknskVg6L+rMgbMHAHik7CNMazGNUgVKWZxM5M5SaRIREaecuXyGvsv6Mu+XeQCE+Ycxuflknqr0lCZ6S56g0iQiIv/IGMOHv3xIn2V9HBO9O9fuzKimowjyCbI6nshdo9IkIiJ/a/+Z/XRe1JnVv60GoGpIVWa1mkVEsQhrg4lYQKVJRET+IiU9hTHrxjBq3SjHRO9hTYbR6/5emugteZZKk4iIZLL6t9V0XtiZ/Wf3A9C8bHOmt5iuid6S56k0iYgIcHWid7/l/fhg+wcAhOYLZVLzSTxT+RlN9BZBpUlEJM/780RvgM61OjM6crQmeotcR6VJRCQPO3D2AJ0Xdib2t1gAqoRUYXar2ZroLXIDKk0iInlQSnoKY9ePZeTakY6J3kMbD6V3RG9N9Bb5GypNIiJ5zJrf1tBpYSfHRO+oMlFMbzmd0gVKW5xMxLWpNImI5BFnL5+l3/J+vL/9fUATvUVulUqTiEguZ4zhox0f0WdZH85cPgNAp1qdGBM5RhO9RW6BSpOISC524OwBuizqwqojqwCoXLgys1vPpn6x+hYnE8l5VJpERHKhlPQUxq0fx8i1I0nJSMHHw8cx0dvL3cvqeCI5kkqTiEgu88PvP9BpYSf2ndkH
wMNlHmZGyxma6C1ym1SaRERyibOXz9J/eX/e2/4eACH5QpjUfBLPVn5WE71FsoFKk4hIDmeM4eMdH9N7WW/HRO+ONTsyJnIMBXwLWJxOJPdQaRIRycEOnj1Il0VdWHlkJXB1ovesVrNoULyBxclEch+VJhGRHCglPYW3f3ybt354yzHRe0ijIfSp30cTvUXuEJUmEZEcZu3va+m0sBN7z+wFoFnpZsxoOYMyBctYnEwkd1NpEhHJIc5dOUf/5f2Z+/Nc4OpE74lRE3muynOa6C1yF6g0iYi4OGMM83fOp/fS3py+fBqA6JrRjI0cq4neIneRSpOIiAv780TvSoUrMavVLBoWb2hxMpG8R6VJRMQFpWak8vb6t3nzhzcdE70HNxpM3/p9NdFbxCIqTSIiLmbd0XV0XNAx00Tv6S2nU7ZgWYuTieRtKk0iIi7i3JVzDFg+gH///G8ACvsVZmLziTxf5XlN9BZxASpNIiIWM8bwyc5P6LW0l2Oi9ys1XmFss7EU9C1ocToRuUalSUTEQr+e+5Uui7qw/PByACoWqsisVrN4oMQDFicTkT9TaRIRsUBqRirjfxzPmz+8SXJ6Mt7u3gxuNJh+DfpporeIi1JpEhG5y9YdXUenhZ3Yc3oPAE1LNWVmq5ma6C3i4lSaRETukvNXzjNgxQDmbJsDQCG/Qrwb9S5tq7bVRG+RHEClSUTkDjPG8OWeL+n+fXdOXToFQIcaHRgbOZZgv2CL04mIs1SaRETuoN8TfidmcQyLDi4CoHxweWa3nk2jEo0sTiYit0qlSUTkDsiwZzBl8xTeWPUGl9Iu4enmyWsPvMaghoPw9vC2Op6IZIFKk4hINvv55M9EL4hm68mtADQs3pDZrWZTsXBFi5OJyO1QaRIRySaXUi8xbPUw3t34Lhkmg0DvQN5u9jYdanbAzeZmdTwRuU0qTSIi2WDpoaV0XtSZ3xJ+A+CZys8wMWoi4fnDrQ0mItlGpUlE5DbEX4qn19JefLLzEwCKBRRjesvptLq3lcXJRCS7qTSJiGSBMYYPtn9An2V9OJ98HjebG93rdufNh97E38vf6ngicgeoNImI3KIDZw/QaWEnVv+2GoDqodWZ03oOde6pY20wEbmjVJpERJyUmpHK2+vf5s0f3iQlIwVfD1+GNxlOz/t74unuaXU8EbnDLD+c4/jx47zwwgsEBwfj6+tL1apV2bJli2O5MYYhQ4YQHh6Or68vkZGRHDx40MLEIpIX/XjsR2rOqskbsW+QkpFCVJkodnfdTb8G/VSYRPIIS0vT+fPnadCgAZ6ennz//ffs2bOHd955hwIFCjjWGTduHJMnT2bmzJls2rSJfPnyERUVRXJysoXJRSSvSExOpOuirjR8ryG7T++msF9h5j85n+/bfk+pAqWsjicid5HNGGOsevOBAweyfv161q5de8PlxhiKFClCnz596Nu3LwCJiYmEhobywQcf8Nxzz930PZKSkggMDCQxMZGAgIBszS8iuZcxhq/3fU23xd04efEkAC/d9xJvN3tb14sTuQtc8fvb0i1N3333HbVr1+bpp58mJCSEGjVqMGfOHMfyI0eOEBcXR2RkpOOxwMBA6tWrx4YNG274mikpKSQlJWW6iYjcij+S/uCJz5+gzRdtOHnxJGULlmXlv1by3mPvqTCJ5GGWlqbDhw8zY8YMypUrx9KlS+nSpQvdu3dn3rx5AMTFxQEQGhqa6XmhoaGOZX82evRoAgMDHbdixYrd2UGISK6RYc9g6uapVJpWiW/3f4uHmwevP/A6Ozrv4KFSD1kdT0QsZunRc3a7ndq1azNq1CgAatSowa5du5g5cybt27fP0msOGjSI3r17O+4nJSWpOInITe04tYOOCzqy6fgmACKKRjC79WyqhFSxOJmIuApLtzSFh4dTqVKlTI9VrFiRo0ePAhAWFgbAqVOnMq1z6tQpx7I/8/b2JiAgINNNROTvXEm7wqAVg6g1uxabjm8iv1d+preY
zrqX16kwiUgmlpamBg0asH///kyPHThwgBIlSgBQqlQpwsLCWLlypWN5UlISmzZtIiIi4q5mFZHcZ8XhFVSdUZUx68eQbk/nyYpPsjdmL13qdNEFdkXkLyzdPderVy/q16/PqFGjeOaZZ9i8eTOzZ89m9uzZANhsNnr27Mlbb71FuXLlKFWqFIMHD6ZIkSI8/vjjVkYXkRzszOUz9FnWhw9/+RCAe/Lfw9QWU3m8wuPWBhMRl2ZpaapTpw5ff/01gwYNYsSIEZQqVYqJEyfStm1bxzr9+/fn0qVLdOzYkYSEBBo2bMiSJUvw8fGxMLmI5ETGGD7e8TG9lvbi7JWz2LARUyeGkU1HEuCtXfki8s8sPU/T3eCK53kQkbvv13O/0nlRZ1YcXgFAlZAqzGk9h/uL3m9xMhG5EVf8/ta150QkV0vLSGPChgkMWzOM5PRkvN29Gdp4KH3r99XlT0Tklqg0iUiutfn4ZqIXRLPj1A4AHir1EDNbzqRccDmLk4lITqTSJCK5zoWUC7yx6g2mbJ6CwRDsG8yEqAm0q9YOm81mdTwRyaFUmkQkV/lu/3fELI7hj6Q/AGhXrR3vPPwOhfMVtjiZiOR0Kk0ikiucvHCSV79/lf/u/S8ApQuUZmbLmTQr08ziZCKSW6g0iUiOZjd2Zm+dzcAVA0lMScTd5k7f+n0Z0ngIfp5+VscTkVxEpUlEcqzd8bvpuLAjPx77EYA6Reowp/UcqodVtziZiORGKk0ikuMkpyczau0oxqwbQ5o9DX8vf0Y+NJKYOjG4u7lbHU9EcimVJhHJUVb/tppOCztx4OwBAFrf25ppLaZRLLCYxclEJLdTaRKRHOHclXP0X96fuT/PBSDMP4ypj0zlyYpP6jQCInJXqDSJiEszxvD57s/psaQH8ZfiAehcqzOjI0cT5BNkbTgRyVNUmkTEZf2W8BtdFnVhyaElAFQsVJHZrWfTsHhDi5OJSF6k0iQiLifdns6kjZMYsnoIl9Mu4+XuxRsPvEH/Bv3x9vC2Op6I5FEqTSLiUrae2Er0gmh+jvsZgMYlGjOr1SzKFypvcTIRyetUmkTEJVxMvcjQ2KFM3DQRu7FTwKcA4x8ez0v3vaSJ3iLiElSaRMRyiw8upuuirvye+DsAz1d5nnej3iXUP9TiZCIi/6PSJCKWOXXxFD2X9uSzXZ8BUCKwBDNazuCRco9YnExE5K9UmkTkrjPG8N7P79F3eV8SkhNws7nR6/5eDG8ynHxe+ayOJyJyQypNInJX7T+zn04LO7Hm9zUA1AyvyZzWc6gZXtPiZCIi/0ylSUTuipT0FMauH8vItSNJzUjFz9OPNx98k+71uuPhpo8iEXF9+qQSkTtu3dF1dFzQkb1n9gLwSNlHmN5yOiWDSlobTETkFqg0icgdk5CcwMAVA5m1dRYAIflCmNx8Ms9UfkanERCRHEelSUSynTGG/+79L69+/ypxF+MAeKXGK4xtNpaCvgUtTicikjUqTSKSrY4lHiNmcQwLDiwA4N7ge5ndajaNSza2OJmIyO1RaRKRbJFhz2Dq5qm8EfsGF1Mv4unmyaCGgxj0wCB8PHysjicicttUmkTktv0S9wvRC6L56cRPADQo1oDZrWdTqXAli5OJiGQflSYRybLLaZcZsWYE438cT4bJINA7kLGRY4muFY2bzc3qeCIi2UqlSUSyZNmvy+i8sDNHEo4A8HSlp5nUfBLh+cMtTiYicmdk6b+C8+bNY9GiRY77/fv3JygoiPr16/P7779nWzgRcT2nL52m3dftiPo4iiMJRygaUJTvnvuOL57+QoVJRHK1LJWmUaNG4evrC8CGDRuYNm0a48aNo1ChQvTq1StbA4qIazDGMG/7PCpMq8DHOz7Gho0e9Xqwp+seWpdvbXU8EZE7Lku7544dO0bZsmUB+Oabb2jTpg0dO3akQYMGNGnSJDvziYgLOHTuEJ0WdmLVkVUAVA+tzuzWs6l7T12Lk4mI3D1Z2tLk7+/P
2bNnAVi2bBnNmjUDwMfHhytXrmRfOhGxVFpGGqPXjqbqjKqsOrIKXw9fxkaO5afon1SYRCTPydKWpmbNmvHKK69Qo0YNDhw4QIsWLQDYvXs3JUqUyNaAImKNjX9sJHpBNLvidwHQrHQzZraaSekCpS1OJiJijSxtaZo2bRoRERGcPn2a//73vwQHBwOwdetW/u///i9bA4rI3ZWUkkS3xd2oP7c+u+J3UcivEB8/8TFLX1iqwiQieZrNGGOy8sTk5GR27NhBfHw8drs907JHH300W8Jlh6SkJAIDA0lMTCQgIMDqOCIu7Zt939BtcTeOXzgOQPvq7Rn/8HgK+RWyOJmI5DWu+P2dpd1zS5Ys4V//+hdnz57lz53LZrORkZGRLeFE5O44fP4wfZb14Zt93wBQtmBZZracSdPSTa0NJiLiQrK0e+7VV1/l6aef5sSJE9jt9kw3FSaRnCMhOYF+y/pRcVpFvtn3DR5uHrzW8DV2dN6hwiQi8idZ2tJ06tQpevfuTWhoaHbnEZG7IC0jjZlbZjJ8zXDOXrl6JGyz0s2YEDWBKiFVLE4nIuKaslSannrqKVavXk2ZMmWyO4+I3EHGGBYcWEC/5f04cPYAAJUKV2J8s/E0L9scm81mcUIREdeVpYngly9f5umnn6Zw4cJUrVoVT0/PTMu7d++ebQFvlytOJBOxwraT2+izrA+rf1sNQGG/wrz54Jt0qNkBDzddhlJEXIsrfn9n6ZPy008/ZdmyZfj4+LB69epM/zu12WwuVZpE8rrjScd5fdXrfPjLhxgM3u7e9I7ozcCGAwnwdo0PIhGRnCBLpen1119n+PDhDBw4EDe3LM0lF5E77GLqRcatH8f4H8dzJf3qmfr/r+r/MeqhUZQI0kloRURuVZZKU2pqKs8++6wKk4gLyrBn8MH2D3gj9g3iLsYB0LB4Q955+B1d+kRE5DZkqfW0b9+ezz//PLuziMhtWnF4BTVn1+SVBa8QdzGOMgXK8J+n/8MPL/6gwiQicpuytKUpIyODcePGsXTpUqpVq/aXieATJkzIlnAi4pw9p/fQb3k/Fh9cDECQTxCDGw0mpk4M3h7eFqcTEckdslSadu7cSY0aNQDYtWtXpmU6ZFnk7om/FM+w1cOYvXU2GSYDDzcPYurEMLjRYIL9gq2OJyKSq2SpNMXGxmZ3DhG5BcnpyUzcOJFRa0dxIfUCAI9XeJyxkWO5N/hei9OJiOROOjmLSA5ijOGzXZ8xaOUgfk/8HYBa4bV45+F3aFyyscXpRERyN5UmkRxi/dH19F7Wm83HNwNQNKAoox4aRdtqbXGz6UhWEZE7TaVJxMX9eu5XBq4cyH/2/AeAfJ75GNhwIL0jeuPn6WdxOhGRvEOlScRFnb9ynpFrRzJ502TS7Gm42dzoUKMDIx4cQZh/mNXxRETyHJUmEReTlpHGzC0zGbZmGOeunAPg4TIPM77ZeKqGVrU4nYhI3qXSJOIijDF8t/87+q/oz4GzBwCoXLgy4x8eT/OyzS1OJyIiKk0iLmDbyW30XtqbNb+vASAkXwgjmoygQ80OeLjpn6mIiCvQp7GIhf5I+oPXV73OR798hMHg4+FDr/t7MbDhQAK8A6yOJyIi11FpErHAxdSLjFs/jvE/judK+hUA2lZty6imoygeWNzidCIiciMqTSJ3UYY9g/e3v8/g2MHEXYwDoGHxhkx4eAJ17qljcToREfknKk0id8nyX5fTZ1kfdsbvBKBMgTKMazaOJyo8oWs2iojkACpNInfYntN76LusL98f+h6AAj4FGNxoMDF1Y/By97I4nYiIOEulSeQOib8Uz9DYoczZNocMk4GHmwfd6nRjcOPBFPQtaHU8ERG5RS5zwaoxY8Zgs9no2bOn47Hk5GRiYmIIDg7G39+fNm3acOrUKetCijjhStoVRq8dTdnJZZm5dSYZJoMnKjzBnq57eLf5uypMIiI5lEuUpp9++olZs2ZRrVq1TI/36tWLBQsW8OWX
X7JmzRpOnDjBk08+aVFKkX9mN3Y+2fkJFaZV4LVVr3Eh9QK1wmux5sU1fPXsV5QLLmd1RBERuQ2Wl6aLFy/Stm1b5syZQ4ECBRyPJyYmMnfuXCZMmMBDDz1ErVq1eP/99/nxxx/ZuHGjhYlF/mr90fVEzI2g7VdtOZp4lKIBRfnoiY/YHL2ZRiUaWR1PRESygeWlKSYmhpYtWxIZGZnp8a1bt5KWlpbp8QoVKlC8eHE2bNhwt2OK3NCv537lqS+eouH7Ddl8fDP+Xv6MfGgkB7od4IVqL+Bms/yfmIiIZBNLJ4J/9tlnbNu2jZ9++ukvy+Li4vDy8iIoKCjT46GhocTFxf3ta6akpJCSkuK4n5SUlG15Ra45f+U8b/3wFlM2TyHNnoabzY1XarzC8AeHE+YfZnU8ERG5AywrTceOHaNHjx4sX74cHx+fbHvd0aNHM3z48Gx7PZHrpWakMuOnGYz4YQTnrpwDIKpMFOMfHk+VkCoWpxMRkTvJsn0HW7duJT4+npo1a+Lh4YGHhwdr1qxh8uTJeHh4EBoaSmpqKgkJCZmed+rUKcLC/v5/8oMGDSIxMdFxO3bs2B0eieQFxhi+2fcNVaZXoefSnpy7co7KhSvzfdvvWfLCEhUmEZE8wLItTU2bNmXnzp2ZHnvppZeoUKECAwYMoFixYnh6erJy5UratGkDwP79+zl69CgRERF/+7re3t54e3vf0eySt2w9sZU+y/qw5vc1AITkC+HNB9/k5Rov4+GmU52JiOQVln3i58+fnypVMv/vPF++fAQHBzse79ChA71796ZgwYIEBATw6quvEhERwf33329FZMljjiUe4/VVr/PRjo8A8PHwoff9vRnYcCD5vfNbnE5ERO42l/5v8rvvvoubmxtt2rQhJSWFqKgopk+fbnUsyeUupFxg3PpxjN8wnuT0ZABeqPYCIx8aSfHA4hanExERq9iMMcbqEHdSUlISgYGBJCYmEhAQYHUccWEZ9gze+/k9BscO5tSlq2eef6D4A7zz8DvUuaeOxelERPIWV/z+duktTSJ3y7Jfl9F3WV92xl+dZ1e2YFnGRo7liQpPYLPZLE4nIiKuQKVJ8rTd8bvpu7wvSw4tAaCATwGGNB5C1zpd8XL3sjidiIi4EpUmyZNOXTzF0NVDmbNtDnZjx9PNk5g6MQxuPFgX1BURkRtSaZI85UraFSZunMjodaO5kHoBgCcrPsnYyLGULVjW4nQiIuLKVJokT7AbO5/u/JTXVr3G0cSjANQuUpt3Hn5HF9QVERGnqDRJrrf297X0WdaHn05cvcZhsYBijG46muerPq8L6oqIiNNUmiTXOnTuEANWDOCrvV8B4O/lz6CGg+h1fy98PX0tTiciIjmNSpPkOuevnOfNH95k6uappNnTcLO58UqNVxjx4AhC/UOtjiciIjmUSpPkGqkZqUz/aToj1ozgfPJ5AJqXbc7bzd7WBXVFROS2qTRJjmeM4Zt939B/RX8OnTsEQJWQKoxvNp6oslEWpxMRkdxCpUlytC0nttBnWR9++P0HAELyhfDWg2/xUo2X8HDTX28REck++laRHOlY4jFeW/UaH+/4GAAfDx/6RPRhQIMB5PfOb3E6ERHJjVSaJEe5kHKBsevH8s6Gd0hOTwbghWovMOqhURQLLGZxOhERyc1UmiRHSLen897P7zEkdginLp0CoFGJRrzz8DvULlLb4nQiIpIXqDSJy1v26zL6LOvDrvhdAJQtWJa3m73NY+Ufw2azWZxORETyCpUmcVmnLp7i1e9f5cs9XwJQwKcAQxsPpUudLni5e1mcTkRE8hqVJnE5xhg+3vExPZf25NyVc7jb3Hm17qsMbjyYgr4FrY4nIiJ5lEqTuJSjiUfpvLAz3x/6HoD7wu5j7qNzqRle0+JkIiKS16k0iUuwGzuztsyi/4r+XEy9iJe7F0MbD6Vf/X54untaHU9ERESlSax34OwBXvnuFdYeXQtA/WL1mfvoXCoUqmBx
MhERkf9RaRLLpNvTmbBhAkNXDyU5PZl8nvkY3XQ0MXVjcLO5WR1PREQkE5UmscQvcb/Q4bsObD25FYBmpZsxu/VsSgaVtDaYiIjI31BpkrsqJT2Ft354izHrx5BuTyfIJ4h3o96lffX2OueSiIi4NJUmuWs2HNtAh+86sPfMXgCeqPAE01pMIzx/uMXJREREbk6lSe64S6mXeH3V60zeNBmDITRfKNNaTKNNpTZWRxMREXGaSpPcUSsPryR6QTRHEo4A8K/q/+LdqHd1kkoREclxVJrkjkhITqDvsr7M/XkuAMUDizOr1Syal21ucTIREZGsUWmSbPftvm/psqgLJy+eBCCmTgyjm44mv3d+i5OJiIhknUqTZJv4S/F0/747n+/+HIByBcsx99G5PFDiAYuTiYiI3D6VJrltxhjm75xPjyU9HBfY7Ve/H0MaD8HX09fqeCIiItlCpUluy7HEY3Re1JnFBxcDUD20OnMfnUutIrUsTiYiIpK9VJokS+zGzuyts+m/vD8XUi/g5e7FkEZD6N+gvy6wKyIiuZJKk9yyg2cPEr0gmjW/rwEgomgEcx+dS8XCFS1OJiIicueoNInT0u3pvLvhXYasHkJyejJ+nn5XL7BbJwZ3N3er44mIiNxRKk3ilB2ndtDhuw5sObEFgMjSkcxuNZtSBUpZnExEROTuUGmSf5SSnsKotaMYtW4U6fZ0Ar0DmRA1gZfue0kX2BURkTxFpUn+1sY/NtLhuw7sOb0HgMcrPM60FtMokr+IxclERETuPpUm+YtLqZcYHDuYiRsnYjCE5Ath6iNTearSU9q6JCIieZZKk2Sy6sgqohdEc/j8YQDaVWvHu1HvEuwXbHEyERERa6k0CXD1Arv9lvXj3z//G4BiAcWY1WoWj5R7xOJkIiIirkGlSfhu/3d0WdSFExdOANC1dldGR44mwDvA4mQiIiKuQ6UpDzt96TTdl3Tns12fAVcvsPvvR/9NoxKNLE4mIiLielSa8iBjDJ/u+pTu33fn7JWzuNnc6BvRl2FNhukCuyIiIn9DpSmP+SPpDzov7Myig4sAqBZajbmPzqV2kdoWJxMREXFtKk15hN3YmbN1Dv2W93NcYHdwo8EMaDBAF9gVERFxgkpTHnDo3CGiF0Sz+rfVANxf9H7mPjqXSoUrWRtMREQkB1FpysXS7elM3DiRwbGDHRfYHfXQKLrV7aYL7IqIiNwilaZcauepnXT4rgM/nfgJgKalmjK79WxKFyhtcTIREZGcSaUpl0nNSL16gd21o0izpxHoHcg7D7/DyzVe1iVQREREboNKUy6y+fhmXv72ZXaf3g3AY+UfY3rL6brAroiISDZQacoFLqddZvCqwUzcNBG7sVPYrzBTW0zl6UpPa+uSiIhINlFpyuFij8TyyoJXHBfYfaHaC7wb9S6F/ApZnExERCR3UWnKoRKTE+m/vD+zt80GoGhAUWa2nEnLe1tanExERCR3UmnKgRYeWEjnhZ05fuE4AF1qd2FM5BhdYFdEROQOUmnKQU5fOk2PJT34dNenAJQtWJZ/t/43jUs2tjiZiIhI7qfSlAMYY/hs12d0X9KdM5fP4GZzo09EH4Y3Ga4L7IqIiNwlKk0u7njScTov6szCAwsBqBpSlfcee08X2BUREbnLVJpclDGGf2/7N32X9yUpJQlPN8+rF9htOAAvdy+r44mIiOQ5Kk0u6NdzvxK9IJrY32IBqHdPPeY+OpfKIZUtTiYiIpJ3qTS5kAx7BpM2TeKNVW9wJf0Kvh6+jHxoJN3rddcFdkVERCym0uQidsXvosN3Hdh8fDMAD5V6iDmt5+gCuyIiIi7Czco3Hz16NHXq1CF//vyEhITw+OOPs3///kzrJCcnExMTQ3BwMP7+/rRp04ZTp05ZlDj7pWakMnz1cGrOqsnm45sJ8A5gTus5rGi3QoVJRETEhVhamtasWUNMTAwbN25k+fLlpKWl8fDDD3Pp0iXHOr169WLBggV8+eWX
rFmzhhMnTvDkk09amDr7/HT8J2rNrsWwNcNIs6fR+t7W7Om6h1dqvqJrxomIiLgYmzHGWB3imtOnTxMSEsKaNWto1KgRiYmJFC5cmE8++YSnnnoKgH379lGxYkU2bNjA/ffff9PXTEpKIjAwkMTERAICXOOM2ZfTLjMkdgjvbnzXcYHdKY9M4ZnKz6gsiYiI4Jrf3y41pykxMRGAggULArB161bS0tKIjIx0rFOhQgWKFy/+t6UpJSWFlJQUx/2kpKQ7nPrWrP5tNdELojl07hAAbau2ZWLzibrAroiIiIuzdPfc9ex2Oz179qRBgwZUqVIFgLi4OLy8vAgKCsq0bmhoKHFxcTd8ndGjRxMYGOi4FStW7E5Hd0piciKdF3bmwXkPcujcIe7Jfw8Ln1/Ix09+rMIkIiKSA7hMaYqJiWHXrl189tlnt/U6gwYNIjEx0XE7duxYNiXMukUHFlF5emVmbZ0FQKdandjddTct721pcTIRERFxlkvsnuvWrRsLFy7khx9+oGjRoo7Hw8LCSE1NJSEhIdPWplOnThEWFnbD1/L29sbb2/tOR3bKmctn6LmkJ/N3zgegTIEy/PvRf9OkZBNrg4mIiMgts3RLkzGGbt268fXXX7Nq1SpKlSqVaXmtWrXw9PRk5cqVjsf279/P0aNHiYiIuNtxnXbtArsVp1Vk/s75uNnc6BvRlx1ddqgwiYiI5FCWbmmKiYnhk08+4dtvvyV//vyOeUqBgYH4+voSGBhIhw4d6N27NwULFiQgIIBXX32ViIgIp46cs8LxpON0XdyV7/Z/B0CVkCrMfXQude+pa3EyERERuR2WnnLg7w6vf//993nxxReBqye37NOnD59++ikpKSlERUUxffr0v90992d365BFYwxzf55L32V9SUxJxNPNkzcavcHAhgN1gV0REZFb5IqnHHCp8zTdCXfjh374/GGiF0Sz6sgqAOreU5e5j86lSkiVO/J+IiIiuZ0rliaXmAieU2XYM5i8aTKvr3rdcYHdtx56ix71eugCuyIiIrmMSlMW7Y7fTYfvOrDp+CYAHiz5IHNaz6FMwTIWJxMREZE7QaUpi2IWx7Dp+CYCvAMY32y8rhcnIiKSy7nMyS1zmuktp/NkxSfZ3XU30bWiVZhERERyOU0EFxEREZfjit/f2tIkIiIi4gSVJhEREREnqDSJiIiIOEGlSURERMQJKk0iIiIiTlBpEhEREXGCSpOIiIiIE1SaRERERJyg0iQiIiLiBJUmERERESeoNImIiIg4QaVJRERExAkqTSIiIiJOUGkSERERcYJKk4iIiIgTVJpEREREnKDSJCIiIuIElSYRERERJ6g0iYiIiDhBpUlERETECSpNIiIiIk5QaRIRERFxgkqTiIiIiBNUmkREREScoNIkIiIi4gSVJhEREREnqDSJiIhIltntkJ5udYq7w8PqACIiImIdY+DSJUhIgPPnr/56/e1mjyUmwrx50K6dRQO4i1SaREREcjBjIDk566UnIQEyMm4vQ0LC7T0/p1BpEhERsVhq6q2Vnj8/npp6+xk8PKBAAQgK+t+v19/+6bECBW7//XMClSYREZHblJ5+dTeVMwXnRo9duXL7GdzcnC85N3rM1xdsttvPkZupNImISJ5nt0NSUtZ3cV28mD05AgOzXnz8/VV67jSVJhERyfGulZ7ExP9NTr6V4pOUdHVu0O3y97+1rTvX3wICwN399jPInaPSJCIilrs2kfnPpcfZXy9cyJ7S4+Nz67u1rt0CA8HT8/YziOtSaRIRkduSkXHjrTy38mt2TGQG8PLKXGL+XGz+qfwEBl4tTSJ/R6VJRCQPu/5w9dvZypMdbLaru6iuFZib/Xqjx1R65E5SaRIRycGubeW5ndKTlpY9WXx8nCs2f/dr/vxXjwATcVUqTSIiFjHm6qHmzhacGz2WXVt53NxubSvPjcqRt3f2ZBFxVSpNIiJZlJb296XGmd8nJGTfNbt8fbNWdK793t9f
W3lEbkalSUTypOsPUc9q8cmOExLC/05KmNXdWoGBVydAi8idpdIkIjmOMXD58u0Vnuw6RB2ubqX5u8LjTBHKl08nJRTJCVSaROSuS029vd1aiYnZt1vL2/vWC8/1vw8IuHrNLhHJ/fRPXURuyfXn5Mlq8UlOzp4sbm4335LzT7/XIeoicitUmkTykJSUq4Xn2u3Chcz3nSk+2XW0Flw9xDwrW3e0W0tErKDSJOLi7Ha4dClzublR4XFmeXaddRmubqHJSuG5freWrrMlIjmJSpPIHZKa+vfF5lYKT3ZOWL4mX76rpeXaLX/+q7db2bWlc/KISF6j0iRynWtHZd3qFpwbLc+ueTvXuLtnLjp/vuXP79wyf39NXBYRyQp9dEqukJ5+8zLjbOGx27M3m5/fzUuNM8t9fDR/R0TESipNYim7/WpRuTbZ+O9uNys72XWSwWuuXVLiVkrNjZbnz6+tOiIiuYU+ziXL/nzo+a3crj0vu+fr+Pjc3taca8v9/LRVR0REMlNpyqPS02+94Pz5dvFi9uXx8sp87pwb3ZyZt+PpmX2ZRERErqfSlANdO5vyrRScP98uX86+PNcfep7Vm47EEhERV6fSdJclJ996wfnzLTuPyvLz+/utOs4WHl0oVERE8gKVpiw6eBB+++3WC092nlwwX77b27oTEKDdWSIiIs5Sacqi0aPh/fez/vzrLx9xKyXn+t/rqCwREZG7R1+7WVSqFFSpkrUtPPnz6/IRIiIiOY3NmOy+QEP2mzZtGm+//TZxcXFUr16dKVOmULduXaeem5SURGBgIImJiQQEBNzhpCIiIpIdXPH7283qADfz+eef07t3b4YOHcq2bduoXr06UVFRxMfHWx1NRERE8hCXL00TJkwgOjqal156iUqVKjFz5kz8/Px47733rI4mIiIieYhLl6bU1FS2bt1KZGSk4zE3NzciIyPZsGHDDZ+TkpJCUlJSppuIiIjI7XLp0nTmzBkyMjIIDQ3N9HhoaChxcXE3fM7o0aMJDAx03IoVK3Y3ooqIiEgu59KlKSsGDRpEYmKi43bs2DGrI4mIiEgu4NKnHChUqBDu7u6cOnUq0+OnTp0iLCzshs/x9vbGW9fkEBERkWzm0luavLy8qFWrFitXrnQ8ZrfbWblyJRERERYmExERkbzGpbc0AfTu3Zv27dtTu3Zt6taty8SJE7l06RIvvfSS1dFEREQkD3H50vTss89y+vRphgwZQlxcHPfddx9Lliz5y+RwERERkTspR5wR/Ha44hlFRURE5J+54ve3S89pEhEREXEVKk0iIiIiTlBpEhEREXGCSpOIiIiIE1z+6LnbdW2eu65BJyIiknNc+952pePVcn1punDhAoCuQSciIpIDnT17lsDAQKtjAHnglAN2u50TJ06QP39+bDab1XGckpSURLFixTh27JjLHGaZnXL7+CD3j1Hjy/ly+xhz+/gg948xMTGR4sWLc/78eYKCgqyOA+SBLU1ubm4ULVrU6hhZEhAQkCv/IVyT28cHuX+MGl/Ol9vHmNvHB7l/jG5urjP92nWSiIiIiLgwlSYRERERJ6g0uSBvb2+GDh2Kt7e31VHuiNw+Psj9Y9T4cr7cPsbcPj7I/WN0xfHl+ongIiIiItlBW5pEREREnKDSJCIiIuIElSYRERERJ6g03QFjxozBZrPRs2dPx2PJycnExMQQHByMv78/bdq04dSpU5med/ToUVq2bImfnx8hISH069eP9PT0TOusXr2amjVr4u3tTdmyZfnggw/+8v7Tpk2jZMmS+Pj4UK9ePTZv3nzbYzp+/DgvvPACwcHB+Pr6UrVqVbZs2eJYboxhyJAhhIeH4+vrS2RkJAcPHsz0GufOnaNt27YEBAQQFBREhw4duHjxYqZ1duzYwQMPPICPjw/FihVj3Lhxf8ny5ZdfUqFCBXx8fKhatSqLFy++7fFlZGQwePBgSpUqha+vL2XKlOHNN9/MdPr+nDTGH374gdatW1Ok
SBFsNhvffPNNpuWuNBZnstzqGNPS0hgwYABVq1YlX758FClShH/961+cOHEix4zxZn+G1+vcuTM2m42JEyfmqvHt3buXRx99lMDAQPLly0edOnU4evSoY7mrf67ebIwXL16kW7duFC1aFF9fXypVqsTMmTMzrePKYxw9ejR16tQhf/78hISE8Pjjj7N//36Xze9Mlpsykq02b95sSpYsaapVq2Z69OjheLxz586mWLFiZuXKlWbLli3m/vvvN/Xr13csT09PN1WqVDGRkZHm559/NosXLzaFChUygwYNcqxz+PBh4+fnZ3r37m327NljpkyZYtzd3c2SJUsc63z22WfGy8vLvPfee2b37t0mOjraBAUFmVOnTmV5TOfOnTMlSpQwL774otm0aZM5fPiwWbp0qTl06JBjnTFjxpjAwEDzzTffmF9++cU8+uijplSpUubKlSuOdZo3b26qV69uNm7caNauXWvKli1rnn/+ecfyxMREExoaatq2bWt27dplPv30U+Pr62tmzZrlWGf9+vXG3d3djBs3zuzZs8e88cYbxtPT0+zcuTPL4zPGmJEjR5rg4GCzcOFCc+TIEfPll18af39/M2nSpBw5xsWLF5vXX3/dfPXVVwYwX3/9dablrjQWZ7Lc6hgTEhJMZGSk+fzzz82+ffvMhg0bTN26dU2tWrUyvYYrj/Fmf4bXfPXVV6Z69eqmSJEi5t1338014zt06JApWLCg6devn9m2bZs5dOiQ+fbbbzN9lrn65+rNxhgdHW3KlCljYmNjzZEjR8ysWbOMu7u7+fbbb3PEGKOiosz7779vdu3aZbZv325atGhhihcvbi5evOiS+W+WxRkqTdnowoULply5cmb58uWmcePGjtKUkJBgPD09zZdffulYd+/evQYwGzZsMMZc/cfl5uZm4uLiHOvMmDHDBAQEmJSUFGOMMf379zeVK1fO9J7PPvusiYqKctyvW7euiYmJcdzPyMgwRYoUMaNHj87yuAYMGGAaNmz4t8vtdrsJCwszb7/9tuOxhIQE4+3tbT799FNjjDF79uwxgPnpp58c63z//ffGZrOZ48ePG2OMmT59uilQoIBjvNfeu3z58o77zzzzjGnZsmWm969Xr57p1KlTlsdnjDEtW7Y0L7/8cqbHnnzySdO2bdscP8Y/f1i70licyZKVMd7I5s2bDWB+//33HDfGvxvfH3/8Ye655x6za9cuU6JEiUylKaeP79lnnzUvvPDC3z4np32u3miMlStXNiNGjMj0WM2aNc3rr7+eI8cYHx9vALNmzRqXy+9MFmdo91w2iomJoWXLlkRGRmZ6fOvWraSlpWV6vEKFChQvXpwNGzYAsGHDBqpWrUpoaKhjnaioKJKSkti9e7djnT+/dlRUlOM1UlNT2bp1a6Z13NzciIyMdKyTFd999x21a9fm6aefJiQkhBo1ajBnzhzH8iNHjhAXF5fpfQMDA6lXr16m8QUFBVG7dm3HOpGRkbi5ubFp0ybHOo0aNcLLyyvT+Pbv38/58+ed+hlkVf369Vm5ciUHDhwA4JdffmHdunU88sgjuWaM17jSWJzJkl0SExOx2WyOa1jl9DHa7XbatWtHv379qFy58l+W5+Tx2e12Fi1axL333ktUVBQhISHUq1cv0+6tnP65Clc/d7777juOHz+OMYbY2FgOHDjAww8/nCPHmJiYCEDBggVdLr8zWZyh0pRNPvvsM7Zt28bo0aP/siwuLg4vL6+/XHAwNDSUuLg4xzrX/6W5tvzasn9aJykpiStXrnDmzBkyMjJuuM6118iKw4cPM2PGDMqVK8fSpUvp0qUL3bt3Z968eZny/dP7xsXFERISkmm5h4cHBQsWzJafwe2MD2DgwIE899xzVKhQAU9PT2rUqEHPnj1p27ZtrhnjNa40FmeyZIfk5GQGDBjA888/77hGV04f49ixY/Hw8KB79+43XJ6TxxcfH8/FixcZM2YMzZs3Z9myZTzxxBM8+eST
rFmzxvG+OflzFWDKlClUqlSJokWL4uXlRfPmzZk2bRqNGjXKcWO02+307NmTBg0aUKVKFZfL70wWZ+T6C/beDceOHaNHjx4sX74cHx8fq+NkO7vdTu3atRk1ahQANWrUYNeuXcycOZP27dtbnC57fPHFF8yfP59PPvmEypUrs337dnr27EmRIkVyzRjzqrS0NJ555hmMMcyYMcPqONli69atTJo0iW3btmGz2ayOk+3sdjsAjz32GL169QLgvvvu48cff2TmzJk0btzYynjZZsqUKWzcuJHvvvuOEiVK8MMPPxATE0ORIkX+smXF1cXExLBr1y7WrVtndZQ7SluassHWrVuJj4+nZs2aeHh44OHhwZo1a5g8eTIeHh6EhoaSmppKQkJCpuedOnWKsLAwAMLCwv4yi//a/ZutExAQgK+vL4UKFcLd3f2G61x7jawIDw+nUqVKmR6rWLGi4yiWa6/9T+8bFhZGfHx8puXp6emcO3cuW34GtzM+gH79+jm2NlWtWpV27drRq1cvx5bD3DDGa1xpLM5kuR3XCtPvv//O8uXLM10JPiePce3atcTHx1O8eHHHZ87vv/9Onz59KFmyZI4fX6FChfDw8Ljp505O/ly9cuUKr732GhMmTKB169ZUq1aNbt268eyzzzJ+/PgcNcZu3bqxcOFCYmNjKVq0qONxV8rvTBZnqDRlg6ZNm7Jz5062b9/uuNWuXZu2bds6fu/p6cnKlSsdz9m/fz9Hjx4lIiICgIiICHbu3JnpQ+7ah/y1D46IiIhMr3FtnWuv4eXlRa1atTKtY7fbWblypWOdrGjQoMFfDiM9cOAAJUqUAKBUqVKEhYVlet+kpCQ2bdqUaXwJCQls3brVsc6qVauw2+3Uq1fPsc4PP/xAWlpapvGVL1+eAgUKOPUzyKrLly/j5pb5n4O7u7vjf7y5YYzXuNJYnMmSVdcK08GDB1mxYgXBwcGZlufkMbZr144dO3Zk+swpUqQI/fr1Y+nSpTl+fF5eXtSpU+cfP3dq1aqVoz9X09LSSEtL+8fPHVcfozGGbt268fXXX7Nq1SpKlSqVabkr5Xcmi1OcnjIut+T6o+eMuXqoY/Hixc2qVavMli1bTEREhImIiHAsv3bY5cMPP2y2b99ulixZYgoXLnzDwy779etn9u7da6ZNm3bDwy69vb3NBx98YPbs2WM6duxogoKCMh2ZcKs2b95sPDw8zMiRI83BgwfN/PnzjZ+fn/n4448d64wZM8YEBQWZb7/91uzYscM89thjNzyEvUaNGmbTpk1m3bp1ply5cpkOf05ISDChoaGmXbt2ZteuXeazzz4zfn5+fzn82cPDw4wfP97s3bvXDB06NFtOOdC+fXtzzz33OE458NVXX5lChQqZ/v3758gxXrhwwfz888/m559/NoCZMGGC+fnnnx1HjrnSWJzJcqtjTE1NNY8++qgpWrSo2b59uzl58qTjdv2RYq48xpv9Gf7Zn4+ey+nj++qrr4ynp6eZPXu2OXjwoOMw87Vr1zpew9U/V282xsaNG5vKlSub2NhYc/jwYfP+++8bHx8fM3369Bwxxi5dupjAwECzevXqTP/GLl++7JL5b5bFGSpNd8ifS9OVK1dM165dTYECBYyfn5954oknzMmTJzM957fffjOPPPKI8fX1NYUKFTJ9+vQxaWlpmdaJjY019913n/Hy8jKlS5c277///l/ee8qUKaZ48eLGy8vL1K1b12zcuPG2x7NgwQJTpUoV4+3tbSpUqGBmz56dabndbjeDBw82oaGhxtvb2zRt2tTs378/0zpnz541zz//vPH39zcBAQHmpZdeMhcuXMi0zi+//GIaNmxovL29zT333GPGjBnzlyxffPGFuffee42Xl5epXLmyWbRo0W2PLykpyfTo0cMUL17c+Pj4mNKlS5vXX3890xdsThpjbGysAf5ya9++vcuNxZkstzrGI0eO3HAZYGJjY3PEGG/2Z/hnNypNOX18c+fO
NWXLljU+Pj6mevXq5ptvvsn0Gq7+uXqzMZ48edK8+OKLpkiRIsbHx8eUL1/evPPOO8Zut+eIMf7dv7HrX9uV8juT5WZs/3/gIiIiIvIPNKdJRERExAkqTSIiIiJOUGkSERERcYJKk4iIiIgTVJpEREREnKDSJCIiIuIElSYRERERJ6g0iYiIiDhBpUlEcowmTZrQs2fPHPfaIpI7eFgdQETEFXz11Vd4enpaHUNEXJhKk4gIULBgQasjiIiL0+45EXHK7NmzKVKkCHa7PdPjjz32GC+//DIAM2bMoEyZMnh5eVG+fHk++uijTOsmJCTQqVMnQkND8fHxoUqVKixcuBCAs2fP8vzzz3PPPffg5+dH1apV+fTTT/+SIz09nW7duhEYGEihQoUYPHgwzl5Cc/r06ZQrVw4fHx9CQ0N56qmnHMuu3z23evVqbDbbX24vvviiY/1vv/2WmjVr4uPjQ+nSpRk+fDjp6elO5RCRnElbmkTEKU8//TSvvvoqsbGxNG3aFIBz586xZMkSFi9ezNdff02PHj2YOHEikZGRLFy4kJdeeomiRYvy4IMPYrfbeeSRR7hw4QIff/wxZcqUYc+ePbi7uwOQnJxMrVq1GDBgAAEBASxatIh27dpRpkwZ6tat68gxb948OnTowObNm9myZQsdO3akePHiREdH/2P+LVu20L17dz766CPq16/PuXPnWLt27Q3XrV+/PidPnnTc37t3Ly1atKBRo0YArF27ln/9619MnjyZBx54gF9//ZWOHTsCMHTo0Kz/kEXEtRkRESc99thj5uWXX3bcnzVrlilSpIjJyMgw9evXN9HR0ZnWf/rpp02LFi2MMcYsXbrUuLm5mf379zv9fi1btjR9+vRx3G/cuLGpWLGisdvtjscGDBhgKlaseNPX+u9//2sCAgJMUlLSDZc3btzY9OjR4y+PnzlzxpQuXdp07drV8VjTpk3NqFGjMq330UcfmfDw8JvmEJGcS7vnRMRpbdu25b///S8pKSkAzJ8/n+eeew43Nzf27t1LgwYNMq3foEED9u7dC8D27dspWrQo99577w1fOyMjgzfffJOqVatSsGBB/P39Wbp0KUePHs203v3334/NZnPcj4iI4ODBg2RkZPxj9mbNmlGiRAlKly5Nu3btmD9/PpcvX/7H56SlpdGmTRtKlCjBpEmTHI//8ssvjBgxAn9/f8ctOjqakydP3vQ1RSTnUmkSEae1bt0aYwyLFi3i2LFjrF27lrZt2zr1XF9f339c/vbbbzNp0iQGDBhAbGws27dvJyoqitTU1OyITv78+dm2bRuffvop4eHhDBkyhOrVq5OQkPC3z+nSpQvHjh3jyy+/xMPjf7MZLl68yPDhw9m+fbvjtnPnTg4ePIiPj0+25BUR16M5TSLiNB8fH5588knmz5/PoUOHKF++PDVr1gSgYsWKrF+/nvbt2zvWX79+PZUqVQKgWrVq/PHHHxw4cOCGW5vWr1/PY489xgsvvACA3W7nwIEDjudfs2nTpkz3N27cSLly5Rxzo/6Jh4cHkZGRREZGMnToUIKCgli1ahVPPvnkX9adMGECX3zxBT/++CPBwcGZltWsWZP9+/dTtmzZm76niOQeKk0ickvatm1Lq1at2L17t6PgAPTr149nnnmGGjVqEBkZyYIFC/jqq69YsWIFAI0bN6ZRo0a0adOGCRMmULZsWfbt24fNZqN58+aUK1eO//znP/z4448UKFCACRMmcOrUqb+UpqNHj9K7d286derEtm3bmDJlCu+8885Ncy9cuJDDhw/TqFEjChQowOLFi7Hb7ZQvX/4v665YsYL+/fszbdo0ChUqRFxcHHB1a1lgYCBDhgyhVatWFC9enKeeego3Nzd++eUXdu3axVtvvXU7P14RcWVWT6oSkZwlIyPDhIeHG8D8+uuvmZZNnz7dlC5d2nh6epp7773XfPjhh5mWnz171rz00ksmODjY+Pj4mCpVqpiFCxc6lj322GPG
39/fhISEmDfeeMP861//Mo899pjj+Y0bNzZdu3Y1nTt3NgEBAaZAgQLmtddeyzQx/O+sXbvWNG7c2BQoUMD4+vqaatWqmc8//zzTa1+bCD506FAD/OXWvn17x/pLliwx9evXN76+viYgIMDUrVvXzJ49+xZ/miKSk9iMcfIEJyIiIiJ5mCaCi4iIiDhBpUlEcoW1a9dmOgXAn28iIrdLu+dEJFe4cuUKx48f/9vlOtJNRG6XSpOIiIiIE7R7TkRERMQJKk0iIiIiTlBpEhEREXGCSpOIiIiIE1SaRERERJyg0iQiIiLiBJUmERERESeoNImIiIg44f8BsTpGBqZLV/MAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GRPO LOSS:\n",
      "   vocab_size    triton      torch\n",
      "0     32000.0  1.496177  17.489887\n",
      "1     64000.0  2.967113  33.858929\n",
      "2     96000.0  4.331829  47.953056\n",
      "3    128000.0  5.840693  63.874176\n",
      "4    160000.0  7.363397  79.620605\n",
      "5    192000.0  8.956139  95.290466\n"
     ]
    }
   ],
   "source": [
    "\n",
    "@triton.testing.perf_report(\n",
    "    triton.testing.Benchmark(\n",
    "        x_names=['vocab_size'],  # argument names to use as an x-axis for the plot\n",
    "        x_vals=[32000 + 16000 * i for i in range(0, 11, 2)],  # different possible values for `x_name`\n",
    "        line_arg='provider',  # argument name whose value corresponds to a different line in the plot\n",
    "        line_vals=['triton', \"torch\"],  # possible values for `line_arg``\n",
    "        line_names=[\n",
    "            \"triton\",\n",
    "            \"torch\",\n",
    "        ],  # label name for the lines\n",
    "        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],  # line styles\n",
    "        ylabel=\"ms\",  # label name for the y-axis\n",
    "        plot_name=\"GRPO LOSS\",  # name for the plot. Used also as a file name for saving the plot.\n",
    "        args={'L': 2048, 'B': 16}\n",
    "    ))\n",
    "def benchmark(B, L, vocab_size, provider):\n",
    "    # Compare forward+backward latency of triton_grpo_loss vs. torch_grpo_loss\n",
    "    # on random inputs, sweeping vocab_size (B, L fixed by `args` above).\n",
    "    device = \"cuda\"\n",
    "    dtype = torch.bfloat16\n",
    "    logits = torch.randn(B, L + 1, vocab_size, device=device, dtype=dtype)\n",
    "    logits.requires_grad_(True)\n",
    "    completion_ids = torch.randint(0, vocab_size-1, (B, L), dtype=torch.int64, device=device)\n",
    "    completion_mask = torch.ones_like(completion_ids, dtype=torch.int32)\n",
    "    # Zero out the last 200 positions so the masked-token code path is exercised too.\n",
    "    completion_mask[:, -200:] = 0\n",
    "    ref_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "    old_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "    advantages = torch.randn(B, device=device, dtype=torch.float32)\n",
    "    temperature, beta, eps_low, eps_high = 0.9, 0.1, 0.2, 0.4\n",
    "    # dy = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "    # NOTE(review): a fresh CUDA stream is created and installed on *every*\n",
    "    # benchmark invocation and never restored; torch.cuda.set_stream is\n",
    "    # discouraged in favor of the `with torch.cuda.stream(...)` context\n",
    "    # manager - TODO confirm this is intentional.\n",
    "    stream = torch.cuda.Stream()\n",
    "    torch.cuda.set_stream(stream)\n",
    "    if provider == 'torch':\n",
    "        ms = triton.testing.do_bench(lambda: torch_grpo_loss(logits,\n",
    "                old_logp,\n",
    "                ref_logp,\n",
    "                completion_ids,\n",
    "                advantages,\n",
    "                completion_mask,\n",
    "                temperature,\n",
    "                beta,\n",
    "                eps_low,\n",
    "                eps_high)[0].sum().backward(), grad_to_none=[logits])\n",
    "    if provider == 'triton':\n",
    "        ms = triton.testing.do_bench(lambda: triton_grpo_loss(logits,\n",
    "                old_logp,\n",
    "                ref_logp,\n",
    "                completion_ids,\n",
    "                advantages,\n",
    "                completion_mask,\n",
    "                temperature,\n",
    "                beta,\n",
    "                eps_low,\n",
    "                eps_high)[0].sum().backward(), grad_to_none=[logits])\n",
    "    return ms\n",
    "benchmark.run(show_plots=True, print_data=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 显存 (GPU memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPU-memory check: run triton_grpo_loss fwd+bwd on pre-materialized logits,\n",
    "# clearing logits.grad each iteration so only the kernel's footprint remains.\n",
    "vocab_size = 151936\n",
    "B, L = 1, 2048\n",
    "device = \"cuda\"\n",
    "dtype = torch.bfloat16\n",
    "logits = torch.randn(B, L + 1, vocab_size, device=device, dtype=dtype)\n",
    "logits.requires_grad_(True)\n",
    "\n",
    "completion_ids = torch.randint(0, vocab_size-1, (B, L), dtype=torch.int64, device=device)\n",
    "completion_mask = None\n",
    "ref_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "# ref_logp = None\n",
    "old_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "# old_logp = None\n",
    "dy = torch.randn_like(ref_logp)\n",
    "advantages = torch.randn(B, device=device, dtype=torch.float32)\n",
    "temperature, beta, eps_low, eps_high = 0.9, 0.1, 0.5, 1.5\n",
    "# Pause so an external monitor (e.g. nvidia-smi) can read the baseline - presumably; TODO confirm.\n",
    "time.sleep(5)\n",
    "\n",
    "for _ in tqdm(range(50)):\n",
    "    loss = triton_grpo_loss(logits,\n",
    "                old_logp,\n",
    "                ref_logp,\n",
    "                completion_ids,\n",
    "                advantages,\n",
    "                completion_mask,\n",
    "                temperature,\n",
    "                beta,\n",
    "                eps_low,\n",
    "                eps_high)[0]\n",
    "    loss.backward(dy)\n",
    "    # Drop the grad tensor each step so memory does not accumulate across iterations.\n",
    "    logits.grad = None\n",
    "# NOTE(review): the two lines below look like recorded memory readings (before -> after); TODO confirm units.\n",
    "# 10 -> 10\n",
    "# 10 -> 56"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:01<00:00, 97.67it/s]\n"
     ]
    }
   ],
   "source": [
    "# GPU-memory check: same loss but with logits produced by a Linear head each\n",
    "# iteration, so the fc weight/input grads are part of the measured footprint.\n",
    "vocab_size = 151936\n",
    "B, L = 1, 2048\n",
    "hidden_size = 1536\n",
    "device = \"cuda\"\n",
    "dtype = torch.bfloat16\n",
    "\n",
    "_input = torch.randn(B, L+1, hidden_size, device=device, dtype=dtype).requires_grad_(True)\n",
    "fc = torch.nn.Linear(hidden_size, vocab_size, device=device, dtype=dtype)\n",
    "\n",
    "completion_ids = torch.randint(0, vocab_size-1, (B, L), dtype=torch.int64, device=device)\n",
    "completion_mask = None\n",
    "ref_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "# ref_logp = None\n",
    "old_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "# old_logp = None\n",
    "dy = torch.randn_like(ref_logp)\n",
    "advantages = torch.randn(B, device=device, dtype=torch.float32)\n",
    "temperature, beta, eps_low, eps_high = 0.9, 0.1, 0.5, 1.5\n",
    "time.sleep(1)\n",
    "\n",
    "for _ in tqdm(range(100)):\n",
    "    logits = fc(_input)\n",
    "    loss = triton_grpo_loss(logits,\n",
    "                old_logp,\n",
    "                ref_logp,\n",
    "                completion_ids,\n",
    "                advantages,\n",
    "                completion_mask,\n",
    "                temperature,\n",
    "                beta,\n",
    "                eps_low,\n",
    "                eps_high)[0]\n",
    "    # NOTE(review): unlike the previous cell, .grad is never reset here, so\n",
    "    # gradients accumulate on _input and fc across all 100 iterations - TODO\n",
    "    # confirm this is intended for the memory measurement.\n",
    "    loss.backward(dy)\n",
    "    \n",
    "# 10 -> 10\n",
    "# 10 -> 56"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline comparison: Liger's fused linear+GRPO loss with the same shapes\n",
    "# (note L here, not L+1: the fused op takes hidden states, not logits).\n",
    "from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss\n",
    "loss_fn = LigerFusedLinearGRPOLoss(compiled=True, \n",
    "                                   chunk_size=1,\n",
    "                                   temperature=0.9)\n",
    "\n",
    "vocab_size = 151936\n",
    "hidden_size = 1536\n",
    "B, L = 16, 2048\n",
    "device = \"cuda\"\n",
    "dtype = torch.bfloat16\n",
    "_input = torch.randn(B, L, hidden_size, device=device, dtype=dtype).requires_grad_(True)\n",
    "fc = torch.nn.Linear(hidden_size, vocab_size, device=device, dtype=dtype)\n",
    "lin_weight = fc.weight\n",
    "\n",
    "completion_ids = torch.randint(0, vocab_size-1, (B, L), dtype=torch.int64, device=device)\n",
    "completion_mask = torch.ones_like(completion_ids)\n",
    "ref_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "# ref_logp = None\n",
    "old_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "# old_logp = None\n",
    "dy = torch.randn_like(ref_logp)\n",
    "advantages = torch.randn(B, device=device, dtype=torch.float32)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:06<00:00,  7.51it/s]\n"
     ]
    }
   ],
   "source": [
    "for _ in tqdm(range(50)):\n",
    "    loss, _ = loss_fn(_input,\n",
    "                   lin_weight,\n",
    "                   completion_ids,\n",
    "                   completion_mask,\n",
    "                   advantages,\n",
    "                 ref_per_token_logps=ref_logp,\n",
    "                old_per_token_logps=old_logp,\n",
    "                   )\n",
    "    loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logits = logits / temperature\n",
    "logits_softmax = F.softmax(logits, -1)\n",
    "logp = torch.gather(logits_softmax, -1, completion_ids.unsqueeze(-1)).squeeze(-1)\n",
    "logp.backward(old_logp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练模拟"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n"
     ]
    }
   ],
   "source": [
    "def zero_grad(model:torch.nn.Module):\n",
    "    # Drop gradients by setting them to None (cheaper than zeroing in-place;\n",
    "    # removed a stray no-op `torch.optim.AdamW` expression left from debugging).\n",
    "    for p in model.parameters():\n",
    "        p.grad = None\n",
    "def set_seed(seed):\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "set_seed(42)\n",
    "\n",
    "device = \"cuda\"\n",
    "dtype = torch.bfloat16\n",
    "\n",
    "B, L = 1, 2048\n",
    "vocab_size = 151936\n",
    "\n",
    "input_ids = torch.randint(0, vocab_size-1, (B, L+100), dtype=torch.int64, device=device)\n",
    "completion_ids = input_ids[:, -L:].contiguous()\n",
    "advantages = torch.randn(B, device=device, dtype=torch.float32)\n",
    "ref_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "old_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "completion_mask = torch.ones_like(completion_ids, dtype=torch.int32)\n",
    "completion_mask[:, -200:] = 0\n",
    "temperature, beta, eps_low, eps_high = 0.9, 0.1, 0.2, 0.4\n",
    "\n",
    "model_path = \"/sharedata/mdy/models/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map=\"cuda\")\n",
    "# config = AutoConfig.from_pretrained(model_path)\n",
    "# config._attn_implementation = \"flash_attention_2\"\n",
    "# config.num_hidden_layers = 2\n",
    "# model = Qwen2ForCausalLM(config).to(dtype).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 11/50 [00:05<00:17,  2.19it/s]"
     ]
    }
   ],
   "source": [
    "model.train()\n",
    "model._set_gradient_checkpointing(False)\n",
    "zero_grad(model)\n",
    "for _ in tqdm(range(50)):\n",
    "    # torch.cuda.empty_cache()\n",
    "    # zero_grad(model)\n",
    "    \n",
    "    logits = model(input_ids, logits_to_keep=L+1).logits\n",
    "    per_token_loss = triton_grpo_loss(logits,\n",
    "                    old_logp,\n",
    "                    ref_logp,\n",
    "                    completion_ids,\n",
    "                    advantages,\n",
    "                    completion_mask,\n",
    "                    temperature,\n",
    "                    beta,\n",
    "                    eps_low,\n",
    "                    eps_high)[0]\n",
    "    loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()\n",
    "    loss.backward()\n",
    "print(list(model.parameters())[-5].grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:01<00:00,  1.23s/it]\n"
     ]
    }
   ],
   "source": [
    "STEP = 1\n",
    "model.train()\n",
    "model._set_gradient_checkpointing(False)\n",
    "zero_grad(model)\n",
    "for _ in tqdm(range(50)):\n",
    "    total_tokens_this_mbs = completion_mask.sum()\n",
    "    bs = input_ids.size(0)\n",
    "    for idx in range(0, bs, STEP):\n",
    "        logits = model(input_ids=input_ids[idx:idx+STEP], \n",
    "                       attention_mask=None, \n",
    "                       logits_to_keep=L+ 1).logits\n",
    "        per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(logits, \n",
    "                                                                    old_logp[idx:idx+STEP],\n",
    "                                                                    ref_logp[idx:idx+STEP],\n",
    "                                                                    completion_ids[idx:idx+STEP],\n",
    "                                                                    advantages[idx:idx+STEP],\n",
    "                                                                    completion_mask[idx:idx+STEP],\n",
    "                                                                    temperature,\n",
    "                                                                    beta,\n",
    "                                                                    eps_low,\n",
    "                                                                    eps_high)\n",
    "        loss = (per_token_loss * completion_mask[idx:idx+STEP]).sum() / total_tokens_this_mbs\n",
    "        loss.backward()\n",
    "# 16.36  61\n",
    "# 23.49  52   \n",
    "# 38.18  48\n",
    "# 69.87  45\n",
    "# 128    44\n",
    "# print(list(model.parameters())[-5].grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:08<00:00,  5.98it/s]\n"
     ]
    }
   ],
   "source": [
    "from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss\n",
    "loss_fn = LigerFusedLinearGRPOLoss(compiled=True, \n",
    "                                   chunk_size=1,\n",
    "                                   temperature=0.9)\n",
    "model.train()\n",
    "model._set_gradient_checkpointing(False)\n",
    "zero_grad(model)\n",
    "for _ in tqdm(range(50)):\n",
    "    _input = model.model(input_ids=input_ids,\n",
    "                        attention_mask=None, ).last_hidden_state[:, -L:]\n",
    "        \n",
    "    loss, _ = loss_fn(_input,\n",
    "                    model.lm_head.weight,\n",
    "                   completion_ids,\n",
    "                   completion_mask,\n",
    "                   advantages,\n",
    "                   ref_per_token_logps=ref_logp,\n",
    "                   old_per_token_logps=old_logp,\n",
    "                   )\n",
    "    loss.backward()\n",
    "# 17.4 \n",
    "# 23.41\n",
    "# 35.74\n",
    "# 63.98\n",
    "# 112.7 47 - 2(compile time)\n",
    "\n",
    "# print(list(model.parameters())[-5].grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>global_bs</th>\n",
       "      <th>micro_bs</th>\n",
       "      <th>micro_micro_bs</th>\n",
       "      <th>seq_len</th>\n",
       "      <th>loss</th>\n",
       "      <th>time(s)</th>\n",
       "      <th>sample/s</th>\n",
       "      <th>memory(G)</th>\n",
       "      <th>speed up</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>1024</td>\n",
       "      <td>16</td>\n",
       "      <td>1</td>\n",
       "      <td>2048</td>\n",
       "      <td>my</td>\n",
       "      <td>77.736639</td>\n",
       "      <td>13.17</td>\n",
       "      <td>15.26</td>\n",
       "      <td>1.00 x</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1024</td>\n",
       "      <td>16</td>\n",
       "      <td>1</td>\n",
       "      <td>2048</td>\n",
       "      <td>liger</td>\n",
       "      <td>81.901334</td>\n",
       "      <td>12.50</td>\n",
       "      <td>16.38</td>\n",
       "      <td>0.95 x</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>1024</td>\n",
       "      <td>16</td>\n",
       "      <td>2</td>\n",
       "      <td>2048</td>\n",
       "      <td>my</td>\n",
       "      <td>66.715685</td>\n",
       "      <td>15.35</td>\n",
       "      <td>22.31</td>\n",
       "      <td>1.17 x</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1024</td>\n",
       "      <td>16</td>\n",
       "      <td>2</td>\n",
       "      <td>2048</td>\n",
       "      <td>liger</td>\n",
       "      <td>69.758016</td>\n",
       "      <td>14.68</td>\n",
       "      <td>22.86</td>\n",
       "      <td>1.11 x</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>1024</td>\n",
       "      <td>16</td>\n",
       "      <td>4</td>\n",
       "      <td>2048</td>\n",
       "      <td>my</td>\n",
       "      <td>61.527914</td>\n",
       "      <td>16.64</td>\n",
       "      <td>35.95</td>\n",
       "      <td>1.26 x</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1024</td>\n",
       "      <td>16</td>\n",
       "      <td>4</td>\n",
       "      <td>2048</td>\n",
       "      <td>liger</td>\n",
       "      <td>64.330899</td>\n",
       "      <td>15.92</td>\n",
       "      <td>36.37</td>\n",
       "      <td>1.21 x</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1024</td>\n",
       "      <td>16</td>\n",
       "      <td>8</td>\n",
       "      <td>2048</td>\n",
       "      <td>my</td>\n",
       "      <td>58.667135</td>\n",
       "      <td>17.45</td>\n",
       "      <td>66.08</td>\n",
       "      <td>1.32 x</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1024</td>\n",
       "      <td>16</td>\n",
       "      <td>8</td>\n",
       "      <td>2048</td>\n",
       "      <td>liger</td>\n",
       "      <td>61.370523</td>\n",
       "      <td>16.69</td>\n",
       "      <td>63.43</td>\n",
       "      <td>1.27 x</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>1024</td>\n",
       "      <td>16</td>\n",
       "      <td>16</td>\n",
       "      <td>2048</td>\n",
       "      <td>my</td>\n",
       "      <td>57.013902</td>\n",
       "      <td>17.96</td>\n",
       "      <td>126.61</td>\n",
       "      <td>1.36 x</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1024</td>\n",
       "      <td>16</td>\n",
       "      <td>16</td>\n",
       "      <td>2048</td>\n",
       "      <td>liger</td>\n",
       "      <td>60.200599</td>\n",
       "      <td>17.01</td>\n",
       "      <td>112.15</td>\n",
       "      <td>1.29 x</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    global_bs  micro_bs  micro_micro_bs  seq_len   loss    time(s)  sample/s  \\\n",
       "12       1024        16               1     2048     my  77.736639     13.17   \n",
       "0        1024        16               1     2048  liger  81.901334     12.50   \n",
       "11       1024        16               2     2048     my  66.715685     15.35   \n",
       "1        1024        16               2     2048  liger  69.758016     14.68   \n",
       "10       1024        16               4     2048     my  61.527914     16.64   \n",
       "2        1024        16               4     2048  liger  64.330899     15.92   \n",
       "9        1024        16               8     2048     my  58.667135     17.45   \n",
       "3        1024        16               8     2048  liger  61.370523     16.69   \n",
       "13       1024        16              16     2048     my  57.013902     17.96   \n",
       "4        1024        16              16     2048  liger  60.200599     17.01   \n",
       "\n",
       "    memory(G) speed up  \n",
       "12      15.26   1.00 x  \n",
       "0       16.38   0.95 x  \n",
       "11      22.31   1.17 x  \n",
       "1       22.86   1.11 x  \n",
       "10      35.95   1.26 x  \n",
       "2       36.37   1.21 x  \n",
       "9       66.08   1.32 x  \n",
       "3       63.43   1.27 x  \n",
       "13     126.61   1.36 x  \n",
       "4      112.15   1.29 x  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "df = pd.read_json('./infos.jsonl', lines=True)\n",
    "df = df[df[\"seq_len\"] == 2048]\n",
    "df = df.sort_values([\"micro_micro_bs\", \"loss\"], ascending=[True, False])\n",
    "df['speed up'] = df[\"sample/s\"].apply(lambda x: f'{(x / df[\"sample/s\"].iloc[0]):.2f} x')\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# grad of clamp op"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _Clamp(torch.autograd.Function):\n",
    "    \"\"\"GRPO-style clipped surrogate loss with a hand-written backward for the clamp.\"\"\"\n",
    "    @staticmethod\n",
    "    def forward(ctx, coef_1, advantages, low, high):\n",
    "        # clone() is the idiomatic tensor copy; copy.deepcopy is overkill for a plain tensor\n",
    "        coef_2 = coef_1.clone()\n",
    "        coef_2[coef_1<low] = low\n",
    "        coef_2[coef_1>high] = high\n",
    "        per_token_loss1 = coef_1 * advantages.unsqueeze(1)\n",
    "        per_token_loss2 = coef_2 * advantages.unsqueeze(1)\n",
    "        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)\n",
    "        # Where min selects the unclamped branch the gradient flows normally; elsewhere it is 0.\n",
    "        mask = per_token_loss1 <= per_token_loss2\n",
    "        ctx.save_for_backward(mask, advantages)\n",
    "        return per_token_loss\n",
    "    \n",
    "    @staticmethod\n",
    "    def backward(ctx, dloss:torch.Tensor):\n",
    "        mask, advantages = ctx.saved_tensors\n",
    "        # d(-coef_1 * A)/d(coef_1) = -A, masked to the tokens where the unclamped branch won\n",
    "        dgrad = -dloss * advantages.unsqueeze(1) * mask\n",
    "        return dgrad, None, None, None\n",
    "\n",
    "def my_clamp(coef_1, advantages, low, high):\n",
    "    \"\"\"Clipped loss via the custom autograd Function above.\"\"\"\n",
    "    return _Clamp.apply(coef_1, advantages, low, high)\n",
    "\n",
    "def offical_clamp(coef_1, advantages, low, high):  # name (sic) kept: called from a later cell\n",
    "    \"\"\"Reference implementation relying on torch.clamp + builtin autograd.\"\"\"\n",
    "    coef_2 = torch.clamp(coef_1, low, high)\n",
    "    per_token_loss1 = coef_1 * advantages.unsqueeze(1)\n",
    "    per_token_loss2 = coef_2 * advantages.unsqueeze(1)\n",
    "    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)\n",
    "    return per_token_loss\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "coef_1 = torch.randn(64,64)\n",
    "coef_1.requires_grad_(True)\n",
    "coef_1_copy = copy.deepcopy(coef_1)\n",
    "dy = torch.randn_like(coef_1)\n",
    "low = -1\n",
    "high = 1\n",
    "advantages = torch.randn(64)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "loss1 = offical_clamp(coef_1, advantages, low, high)\n",
    "loss2 = my_clamp(coef_1_copy, advantages, low, high)\n",
    "loss1.backward(dy)\n",
    "loss2.backward(dy)\n",
    "print(torch.allclose(loss1, loss2))\n",
    "print(torch.allclose(coef_1.grad, coef_1_copy.grad))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1.0000)\n",
      "tensor(0.)\n"
     ]
    }
   ],
   "source": [
    "print(torch.tensor(10).sigmoid())\n",
    "print(torch.tensor(-100).sigmoid())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# decouple logp and loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from triton_grpo_loss.decouple_logp_and_loss import fused_selective_log_softmax, compile_grpo_loss, triton_grpo_loss as triton_grpo_loss2\n",
    "# 对原始函数做了简单修改\n",
    "def selective_log_softmax(logits, input_ids, temperature=0.9):\n",
    "    logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred\n",
    "    logits_to_keep = logits.size(1)\n",
    "    index = input_ids[:, -logits_to_keep:]\n",
    "    logits = logits[:, -logits_to_keep:]\n",
    "    logits = logits / temperature\n",
    "\n",
    "    if logits.dtype in [torch.float32, torch.float64]:\n",
    "        selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)\n",
    "        # loop to reduce peak mem consumption\n",
    "        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])\n",
    "        per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)\n",
    "    else:\n",
    "        # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach\n",
    "        per_token_logps = []\n",
    "        for row_logits, row_labels in zip(logits, index):  # loop to reduce peak mem consumption\n",
    "            row_logps = F.log_softmax(row_logits, dim=-1)\n",
    "            row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)\n",
    "            per_token_logps.append(row_per_token_logps)\n",
    "        per_token_logps = torch.stack(per_token_logps)\n",
    "    return per_token_logps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed):\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "set_seed(40)\n",
    "\n",
    "vocab_size = 128000\n",
    "B, L = 8, 1024\n",
    "device = \"cuda\"\n",
    "dtype = torch.bfloat16\n",
    "logits1 = torch.randn(B, L + 1, vocab_size, device=device, dtype=dtype)\n",
    "logits1.requires_grad_(True)\n",
    "logits2 = deepcopy(logits1)\n",
    "gold_logits = logits1.detach().clone().float()\n",
    "gold_logits.requires_grad_(True)\n",
    "\n",
    "\n",
    "\n",
    "completion_ids = torch.randint(0, vocab_size-1, (B, L), dtype=torch.int64, device=device)\n",
    "completion_mask = torch.ones_like(completion_ids, dtype=torch.int32)\n",
    "# completion_mask[:, -200:] = 0\n",
    "completion_mask = None\n",
    "ref_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "# ref_logp = None\n",
    "old_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "# old_logp = None\n",
    "advantages = torch.randn(B, device=device, dtype=torch.float32)\n",
    "temperature, beta, eps_low, eps_high = 0.9, 0.2, 0.2, 0.4\n",
    "inplace = True\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "最大差异: 0.004827782977372408, 平均差异: 0.0012687613489106297\n",
      "最大差异: 1.8042213412172714e-07, 平均差异: 1.104520030992262e-08\n",
      "最大差异: 0.05038011074066162, 平均差异: 0.004313413519412279\n",
      "最大差异: 0.003868201980367303, 平均差异: 0.00038550799945369363\n"
     ]
    }
   ],
   "source": [
    "logp1 = selective_log_softmax(logits1, completion_ids, temperature)\n",
    "logp2 = fused_selective_log_softmax(logits2, completion_ids, temperature, completion_mask)\n",
    "gold_logp = selective_log_softmax(gold_logits, completion_ids, temperature)\n",
    "dy = torch.randn_like(gold_logp)\n",
    "logp1.backward(dy)\n",
    "logp2.backward(dy)\n",
    "gold_logp.backward(dy)\n",
    "compare(logp1, gold_logp)\n",
    "compare(logp2, gold_logp)\n",
    "compare(logits1.grad, gold_logits.grad)\n",
    "compare(logits2.grad, gold_logits.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.054377317428589\n",
      "21.552736282348633\n"
     ]
    }
   ],
   "source": [
    "print(triton.testing.do_bench(lambda: fused_selective_log_softmax(logits2, completion_ids, temperature, completion_mask).backward(dy)))\n",
    "print(triton.testing.do_bench(lambda: selective_log_softmax(logits1, completion_ids, temperature).backward(dy)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "最大差异: 0.0, 平均差异: 0.0\n",
      "最大差异: 0.0, 平均差异: 0.0\n",
      "最大差异: 0.0, 平均差异: 0.0\n",
      "最大差异: 2.9119386454112828e-05, 平均差异: 8.716664012808906e-08\n"
     ]
    }
   ],
   "source": [
    "logp1 = torch.randn(B, L, device=device, dtype=torch.float32).requires_grad_(True)\n",
    "logp2 = deepcopy(logp1).requires_grad_(True)\n",
    "loss1, kl1, is_clipped1 = compile_grpo_loss(logp1, advantages, old_logp, ref_logp, beta, eps_low, eps_high)\n",
    "# use the decoupled (logp-based) variant, matching the benchmark cell below; the core\n",
    "# triton_grpo_loss expects (logits, old_logp, ref_logp, completion_ids, ...) instead\n",
    "loss2, kl2, is_clipped2 = triton_grpo_loss2(logp2, advantages, old_logp, ref_logp, beta, eps_low, eps_high)\n",
    "dy = torch.randn_like(loss2)\n",
    "loss1.backward(dy)\n",
    "loss2.backward(dy)\n",
    "compare(loss1, loss2)\n",
    "compare(kl1, kl2)\n",
    "compare(is_clipped1, is_clipped2)\n",
    "compare(logp1.grad, logp2.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.2072295844554901\n",
      "0.1999889463186264\n"
     ]
    }
   ],
   "source": [
    "print(triton.testing.do_bench(lambda: compile_grpo_loss(logp1, advantages, old_logp, ref_logp, beta, eps_low, eps_high)[0].backward(dy)))\n",
    "print(triton.testing.do_bench(lambda: triton_grpo_loss2(logp2, advantages, old_logp, ref_logp, beta, eps_low, eps_high)[0].backward(dy)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed):\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "set_seed(40)\n",
    "\n",
    "vocab_size = 128000\n",
    "B, L = 8, 1024\n",
    "device = \"cuda\"\n",
    "dtype = torch.bfloat16\n",
    "logits1 = torch.randn(B, L + 1, vocab_size, device=device, dtype=dtype)\n",
    "logits1.requires_grad_(True)\n",
    "logits2 = deepcopy(logits1)\n",
    "\n",
    "completion_ids = torch.randint(0, vocab_size-1, (B, L), dtype=torch.int64, device=device)\n",
    "completion_mask = torch.ones_like(completion_ids, dtype=torch.int32)\n",
    "# completion_mask[:, -200:] = 0\n",
    "completion_mask = None\n",
    "ref_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "# ref_logp = None\n",
    "old_logp = torch.randn(B, L, device=device, dtype=torch.float32)\n",
    "# old_logp = None\n",
    "advantages = torch.randn(B, device=device, dtype=torch.float32)\n",
    "temperature, beta, eps_low, eps_high = 0.9, 0.2, 0.2, 0.4\n",
    "inplace = True\n",
    "\n",
    "def combine_func(logits, old_logp, ref_logp, completion_ids, advantages,\n",
    "                 completion_mask=None, temperature=0.9, beta=0.04,\n",
    "                 eps_low=0.2, eps_high=0.4, inplace=True):\n",
    "    \"\"\"Fused selective log-softmax followed by the decoupled triton GRPO loss.\n",
    "\n",
    "    Returns (per_token_loss, per_token_kl, is_clipped), mirroring triton_grpo_loss.\n",
    "    \"\"\"\n",
    "    assert logits is not None and completion_ids is not None and advantages is not None, \"must provide logits、completion_ids and advantages\"\n",
    "    logp = fused_selective_log_softmax(logits, completion_ids, temperature, completion_mask)\n",
    "    return triton_grpo_loss2(logp, advantages, old_logp, ref_logp, beta, eps_low, eps_high)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "最大差异: 0.0, 平均差异: 0.0\n",
      "最大差异: 0.0, 平均差异: 0.0\n",
      "最大差异: 0.0, 平均差异: 0.0\n"
     ]
    }
   ],
   "source": [
    "loss1, kl1, is_clipped1 = triton_grpo_loss(logits1,\n",
    "                old_logp,\n",
    "                ref_logp,\n",
    "                completion_ids,\n",
    "                advantages,\n",
    "                completion_mask,\n",
    "                temperature,\n",
    "                beta,\n",
    "                eps_low,\n",
    "                eps_high)\n",
    "\n",
    "loss2, kl2, is_clipped2 = combine_func(logits2,\n",
    "                old_logp,\n",
    "                ref_logp,\n",
    "                completion_ids,\n",
    "                advantages,\n",
    "                completion_mask,\n",
    "                temperature,\n",
    "                beta,\n",
    "                eps_low,\n",
    "                eps_high,\n",
    "                )\n",
    "compare(loss1, loss2)\n",
    "compare(kl1, kl2)\n",
    "dy = torch.randn_like(loss1)\n",
    "loss1.backward(dy)\n",
    "loss2.backward(dy)\n",
    "compare(logits1.grad, logits2.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.04909086227417\n",
      "3.0818052291870117\n"
     ]
    }
   ],
   "source": [
    "print(triton.testing.do_bench(lambda: triton_grpo_loss(logits1,\n",
    "                old_logp,\n",
    "                ref_logp,\n",
    "                completion_ids,\n",
    "                advantages,\n",
    "                completion_mask,\n",
    "                temperature,\n",
    "                beta,\n",
    "                eps_low,\n",
    "                eps_high)[0].backward(dy)))\n",
    "print(triton.testing.do_bench(lambda: combine_func(logits2,\n",
    "                old_logp,\n",
    "                ref_logp,\n",
    "                completion_ids,\n",
    "                advantages,\n",
    "                completion_mask,\n",
    "                temperature,\n",
    "                beta,\n",
    "                eps_low,\n",
    "                eps_high,\n",
    "                )[0].backward(dy)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
